diff options
Diffstat (limited to '')
-rw-r--r-- | tests/fix_db.py | 358 |
1 files changed, 358 insertions, 0 deletions
diff --git a/tests/fix_db.py b/tests/fix_db.py new file mode 100644 index 0000000..3a37aa1 --- /dev/null +++ b/tests/fix_db.py @@ -0,0 +1,358 @@ +import io +import os +import sys +import pytest +import logging +from contextlib import contextmanager +from typing import Optional + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg._compat import cache +from psycopg.pq._debug import PGconnDebug + +from .utils import check_postgres_version + +# Set by warm_up_database() the first time the dsn fixture is used +pg_version: int +crdb_version: Optional[int] + + +def pytest_addoption(parser): + parser.addoption( + "--test-dsn", + metavar="DSN", + default=os.environ.get("PSYCOPG_TEST_DSN"), + help=( + "Connection string to run database tests requiring a connection" + " [you can also use the PSYCOPG_TEST_DSN env var]." + ), + ) + parser.addoption( + "--pq-trace", + metavar="{TRACEFILE,STDERR}", + default=None, + help="Generate a libpq trace to TRACEFILE or STDERR.", + ) + parser.addoption( + "--pq-debug", + action="store_true", + default=False, + help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)", + ) + + +def pytest_report_header(config): + dsn = config.getoption("--test-dsn") + if dsn is None: + return [] + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + server_version = conn.execute("select version()").fetchall()[0][0] + except Exception as ex: + server_version = f"unknown ({ex})" + + return [ + f"Server version: {server_version}", + ] + + +def pytest_collection_modifyitems(items): + for item in items: + for name in item.fixturenames: + if name in ("pipeline", "apipeline"): + item.add_marker(pytest.mark.pipeline) + break + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="pipeline"): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + + +def pytest_configure(config): + # register pg marker + markers = [ + "pg(version_expr): run the test only with matching server version" + " (e.g. '>= 10', '< 9.6')", + "pipeline: the test runs with connection in pipeline mode", + ] + for marker in markers: + config.addinivalue_line("markers", marker) + + +@pytest.fixture(scope="session") +def session_dsn(request): + """ + Return the dsn used to connect to the `--test-dsn` database (session-wide). + """ + dsn = request.config.getoption("--test-dsn") + if dsn is None: + pytest.skip("skipping test as no --test-dsn") + + warm_up_database(dsn) + return dsn + + +@pytest.fixture +def dsn(session_dsn, request): + """Return the dsn used to connect to the `--test-dsn` database.""" + check_connection_version(request.node) + return session_dsn + + +@pytest.fixture(scope="session") +def tracefile(request): + """Open and yield a file for libpq client/server communication traces if + --pq-tracefile option is set. + """ + tracefile = request.config.getoption("--pq-trace") + if not tracefile: + yield None + return + + if tracefile.lower() == "stderr": + try: + sys.stderr.fileno() + except io.UnsupportedOperation: + raise pytest.UsageError( + "cannot use stderr for --pq-trace (in-memory file?)" + ) from None + + yield sys.stderr + return + + with open(tracefile, "w") as f: + yield f + + +@contextmanager +def maybe_trace(pgconn, tracefile, function): + """Handle libpq client/server communication traces for a single test + function. + """ + if tracefile is None: + yield None + return + + if tracefile != sys.stderr: + title = f" {function.__module__}::{function.__qualname__} ".center(80, "=") + tracefile.write(title + "\n") + tracefile.flush() + + pgconn.trace(tracefile.fileno()) + try: + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + except psycopg.NotSupportedError: + pass + try: + yield None + finally: + pgconn.untrace() + + +@pytest.fixture(autouse=True) +def pgconn_debug(request): + if not request.config.getoption("--pq-debug"): + return + if pq.__impl__ != "python": + raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug") + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + pq.PGconn = PGconnDebug + + +@pytest.fixture +def pgconn(dsn, request, tracefile): + """Return a PGconn connection open to `--test-dsn`.""" + check_connection_version(request.node) + + conn = pq.PGconn.connect(dsn.encode()) + if conn.status != pq.ConnStatus.OK: + pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") + + with maybe_trace(conn, tracefile, request.function): + yield conn + + conn.finish() + + +@pytest.fixture +def conn(conn_cls, dsn, request, tracefile): + """Return a `Connection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = conn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +def pipeline(request, conn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + with conn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture +async def aconn(dsn, aconn_cls, request, tracefile): + """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = await aconn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + await conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +async def apipeline(request, aconn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + async with aconn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture(scope="session") +def conn_cls(session_dsn): + cls = psycopg.Connection + if crdb_version: + from psycopg.crdb import CrdbConnection + + cls = CrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def aconn_cls(session_dsn): + cls = psycopg.AsyncConnection + if crdb_version: + from psycopg.crdb import AsyncCrdbConnection + + cls = AsyncCrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def svcconn(conn_cls, session_dsn): + """ + Return a session `Connection` connected to the ``--test-dsn`` database. + """ + conn = conn_cls.connect(session_dsn, autocommit=True) + yield conn + conn.close() + + +@pytest.fixture +def commands(conn, monkeypatch): + """The list of commands issued internally by the test connection.""" + yield patch_exec(conn, monkeypatch) + + +@pytest.fixture +def acommands(aconn, monkeypatch): + """The list of commands issued internally by the test async connection.""" + yield patch_exec(aconn, monkeypatch) + + +def patch_exec(conn, monkeypatch): + """Helper to implement the commands fixture both sync and async.""" + _orig_exec_command = conn._exec_command + L = ListPopAll() + + def _exec_command(command, *args, **kwargs): + cmdcopy = command + if isinstance(cmdcopy, bytes): + cmdcopy = cmdcopy.decode(conn.info.encoding) + elif isinstance(cmdcopy, sql.Composable): + cmdcopy = cmdcopy.as_string(conn) + + L.append(cmdcopy) + return _orig_exec_command(command, *args, **kwargs) + + monkeypatch.setattr(conn, "_exec_command", _exec_command) + return L + + +class ListPopAll(list): # type: ignore[type-arg] + """A list, with a popall() method.""" + + def popall(self): + out = self[:] + del self[:] + return out + + +def check_connection_version(node): + try: + pg_version + except NameError: + # First connection creation failed. Let the tests fail. + pytest.fail("server version not available") + + for mark in node.iter_markers(): + if mark.name == "pg": + assert len(mark.args) == 1 + msg = check_postgres_version(pg_version, mark.args[0]) + if msg: + pytest.skip(msg) + + elif mark.name in ("crdb", "crdb_skip"): + from .fix_crdb import check_crdb_version + + msg = check_crdb_version(crdb_version, mark) + if msg: + pytest.skip(msg) + + +@pytest.fixture +def hstore(svcconn): + try: + with svcconn.transaction(): + svcconn.execute("create extension if not exists hstore") + except psycopg.Error as e: + pytest.skip(str(e)) + + +@cache +def warm_up_database(dsn: str) -> None: + """Connect to the database before returning a connection. + + In the CI sometimes, the first test fails with a timeout, probably because + the server hasn't started yet. Absorb the delay before the test. + + In case of error, abort the test run entirely, to avoid failing downstream + hundreds of times. + """ + global pg_version, crdb_version + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + conn.execute("select 1") + + pg_version = conn.info.server_version + + crdb_version = None + param = conn.info.parameter_status("crdb_version") + if param: + from psycopg.crdb import CrdbConnectionInfo + + crdb_version = CrdbConnectionInfo.parse_crdb_version(param) + except Exception as exc: + pytest.exit(f"failed to connect to the test database: {exc}") |