diff options
Diffstat (limited to '')
-rw-r--r-- | tests/utils.py | 130 |
1 files changed, 100 insertions, 30 deletions
diff --git a/tests/utils.py b/tests/utils.py index 32636b7..5636a13 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,55 +1,125 @@ -import contextlib -import http.server -import pathlib -import threading +from __future__ import annotations + +__all__ = ('http_server',) + +import socket +from contextlib import contextmanager +from http.server import ThreadingHTTPServer +from pathlib import Path from ssl import PROTOCOL_TLS_SERVER, SSLContext +from threading import Thread +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +if TYPE_CHECKING: + from collections.abc import Iterator + from http.server import HTTPServer + from socketserver import BaseRequestHandler + from typing import Final -import filelock + from sphinx.application import Sphinx # Generated with: # $ openssl req -new -x509 -days 3650 -nodes -out cert.pem \ # -keyout cert.pem -addext "subjectAltName = DNS:localhost" -TESTS_ROOT = pathlib.Path(__file__).parent -CERT_FILE = str(TESTS_ROOT / "certs" / "cert.pem") +TESTS_ROOT: Final[Path] = Path(__file__).parent +CERT_FILE: Final[str] = str(TESTS_ROOT / 'certs' / 'cert.pem') -# File lock for tests -LOCK_PATH = str(TESTS_ROOT / 'test-server.lock') +class HttpServerThread(Thread): + def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None: + """ + Constructs a threaded HTTP server. The default port number of ``0`` + delegates selection of a port number to bind to to Python. -class HttpServerThread(threading.Thread): - def __init__(self, handler, *args, **kwargs): - super().__init__(*args, **kwargs) - self.server = http.server.ThreadingHTTPServer(("localhost", 7777), handler) + Ref: https://docs.python.org/3.11/library/socketserver.html#asynchronous-mixins + """ + super().__init__(daemon=True) + self.server = ThreadingHTTPServer(('localhost', port), handler) - def run(self): + def run(self) -> None: self.server.serve_forever(poll_interval=0.001) - def terminate(self): + def terminate(self) -> None: self.server.shutdown() self.server.server_close() self.join() class HttpsServerThread(HttpServerThread): - def __init__(self, handler, *args, **kwargs): - super().__init__(handler, *args, **kwargs) + def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None: + super().__init__(handler, port=port) sslcontext = SSLContext(PROTOCOL_TLS_SERVER) sslcontext.load_cert_chain(CERT_FILE) self.server.socket = sslcontext.wrap_socket(self.server.socket, server_side=True) -def create_server(thread_class): - def server(handler): - lock = filelock.FileLock(LOCK_PATH) - with lock: - server_thread = thread_class(handler, daemon=True) - server_thread.start() - try: - yield server_thread - finally: - server_thread.terminate() - return contextlib.contextmanager(server) +@contextmanager +def http_server( + handler: type[BaseRequestHandler], + *, + tls_enabled: bool = False, + port: int = 0, +) -> Iterator[HTTPServer]: + server_cls = HttpsServerThread if tls_enabled else HttpServerThread + server_thread = server_cls(handler, port=port) + server_thread.start() + server_port = server_thread.server.server_port + assert port == 0 or server_port == port + try: + socket.create_connection(('localhost', server_port), timeout=0.5).close() + yield server_thread.server # Connection has been confirmed possible; proceed. + finally: + server_thread.terminate() + + +@contextmanager +def rewrite_hyperlinks(app: Sphinx, server: HTTPServer) -> Iterator[None]: + """ + Rewrite hyperlinks that refer to network location 'localhost:7777', + allowing that location to vary dynamically with the arbitrary test HTTP + server port assigned during unit testing. + + :param app: The Sphinx application where link replacement is to occur. + :param server: Destination server to redirect the hyperlinks to. + """ + match_netloc, replacement_netloc = ( + 'localhost:7777', + f'localhost:{server.server_port}', + ) + + def rewrite_hyperlink(_app: Sphinx, uri: str) -> str | None: + parsed_uri = urlparse(uri) + if parsed_uri.netloc != match_netloc: + return uri + return parsed_uri._replace(netloc=replacement_netloc).geturl() + + listener_id = app.connect('linkcheck-process-uri', rewrite_hyperlink) + yield + app.disconnect(listener_id) + + +@contextmanager +def serve_application( + app: Sphinx, + handler: type[BaseRequestHandler], + *, + tls_enabled: bool = False, + port: int = 0, +) -> Iterator[str]: + """ + Prepare a temporary server to handle HTTP requests related to the links + found in a Sphinx application project. + :param app: The Sphinx application. + :param handler: Determines how each request will be handled. + :param tls_enabled: Whether TLS (SSL) should be enabled for the server. + :param port: Optional server port (default: auto). -http_server = create_server(HttpServerThread) -https_server = create_server(HttpsServerThread) + :return: The address of the temporary HTTP server. + """ + with ( + http_server(handler, tls_enabled=tls_enabled, port=port) as server, + rewrite_hyperlinks(app, server), + ): + yield f'localhost:{server.server_port}' |