summaryrefslogtreecommitdiffstats
path: root/tests/fix_db.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/fix_db.py358
1 files changed, 358 insertions, 0 deletions
diff --git a/tests/fix_db.py b/tests/fix_db.py
new file mode 100644
index 0000000..3a37aa1
--- /dev/null
+++ b/tests/fix_db.py
@@ -0,0 +1,358 @@
+import io
+import os
+import sys
+import pytest
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg._compat import cache
+from psycopg.pq._debug import PGconnDebug
+
+from .utils import check_postgres_version
+
+# Set by warm_up_database() the first time the dsn fixture is used
+pg_version: int
+crdb_version: Optional[int]
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--test-dsn",
+ metavar="DSN",
+ default=os.environ.get("PSYCOPG_TEST_DSN"),
+ help=(
+ "Connection string to run database tests requiring a connection"
+ " [you can also use the PSYCOPG_TEST_DSN env var]."
+ ),
+ )
+ parser.addoption(
+ "--pq-trace",
+ metavar="{TRACEFILE,STDERR}",
+ default=None,
+ help="Generate a libpq trace to TRACEFILE or STDERR.",
+ )
+ parser.addoption(
+ "--pq-debug",
+ action="store_true",
+ default=False,
+ help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)",
+ )
+
+
+def pytest_report_header(config):
+ dsn = config.getoption("--test-dsn")
+ if dsn is None:
+ return []
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ server_version = conn.execute("select version()").fetchall()[0][0]
+ except Exception as ex:
+ server_version = f"unknown ({ex})"
+
+ return [
+ f"Server version: {server_version}",
+ ]
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ for name in item.fixturenames:
+ if name in ("pipeline", "apipeline"):
+ item.add_marker(pytest.mark.pipeline)
+ break
+
+
+def pytest_runtest_setup(item):
+ for m in item.iter_markers(name="pipeline"):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+
+
+def pytest_configure(config):
+ # register pg marker
+ markers = [
+ "pg(version_expr): run the test only with matching server version"
+ " (e.g. '>= 10', '< 9.6')",
+ "pipeline: the test runs with connection in pipeline mode",
+ ]
+ for marker in markers:
+ config.addinivalue_line("markers", marker)
+
+
+@pytest.fixture(scope="session")
+def session_dsn(request):
+ """
+ Return the dsn used to connect to the `--test-dsn` database (session-wide).
+ """
+ dsn = request.config.getoption("--test-dsn")
+ if dsn is None:
+ pytest.skip("skipping test as no --test-dsn")
+
+ warm_up_database(dsn)
+ return dsn
+
+
+@pytest.fixture
+def dsn(session_dsn, request):
+ """Return the dsn used to connect to the `--test-dsn` database."""
+ check_connection_version(request.node)
+ return session_dsn
+
+
+@pytest.fixture(scope="session")
+def tracefile(request):
+ """Open and yield a file for libpq client/server communication traces if
+ --pq-tracefile option is set.
+ """
+ tracefile = request.config.getoption("--pq-trace")
+ if not tracefile:
+ yield None
+ return
+
+ if tracefile.lower() == "stderr":
+ try:
+ sys.stderr.fileno()
+ except io.UnsupportedOperation:
+ raise pytest.UsageError(
+ "cannot use stderr for --pq-trace (in-memory file?)"
+ ) from None
+
+ yield sys.stderr
+ return
+
+ with open(tracefile, "w") as f:
+ yield f
+
+
+@contextmanager
+def maybe_trace(pgconn, tracefile, function):
+ """Handle libpq client/server communication traces for a single test
+ function.
+ """
+ if tracefile is None:
+ yield None
+ return
+
+ if tracefile != sys.stderr:
+ title = f" {function.__module__}::{function.__qualname__} ".center(80, "=")
+ tracefile.write(title + "\n")
+ tracefile.flush()
+
+ pgconn.trace(tracefile.fileno())
+ try:
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ except psycopg.NotSupportedError:
+ pass
+ try:
+ yield None
+ finally:
+ pgconn.untrace()
+
+
+@pytest.fixture(autouse=True)
+def pgconn_debug(request):
+ if not request.config.getoption("--pq-debug"):
+ return
+ if pq.__impl__ != "python":
+ raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug")
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+ pq.PGconn = PGconnDebug
+
+
+@pytest.fixture
+def pgconn(dsn, request, tracefile):
+ """Return a PGconn connection open to `--test-dsn`."""
+ check_connection_version(request.node)
+
+ conn = pq.PGconn.connect(dsn.encode())
+ if conn.status != pq.ConnStatus.OK:
+ pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
+
+ with maybe_trace(conn, tracefile, request.function):
+ yield conn
+
+ conn.finish()
+
+
+@pytest.fixture
+def conn(conn_cls, dsn, request, tracefile):
+ """Return a `Connection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = conn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+def pipeline(request, conn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ with conn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture
+async def aconn(dsn, aconn_cls, request, tracefile):
+ """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = await aconn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ await conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+async def apipeline(request, aconn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ async with aconn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture(scope="session")
+def conn_cls(session_dsn):
+ cls = psycopg.Connection
+ if crdb_version:
+ from psycopg.crdb import CrdbConnection
+
+ cls = CrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def aconn_cls(session_dsn):
+ cls = psycopg.AsyncConnection
+ if crdb_version:
+ from psycopg.crdb import AsyncCrdbConnection
+
+ cls = AsyncCrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def svcconn(conn_cls, session_dsn):
+ """
+ Return a session `Connection` connected to the ``--test-dsn`` database.
+ """
+ conn = conn_cls.connect(session_dsn, autocommit=True)
+ yield conn
+ conn.close()
+
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(conn, monkeypatch)
+
+
+@pytest.fixture
+def acommands(aconn, monkeypatch):
+ """The list of commands issued internally by the test async connection."""
+ yield patch_exec(aconn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+ """Helper to implement the commands fixture both sync and async."""
+ _orig_exec_command = conn._exec_command
+ L = ListPopAll()
+
+ def _exec_command(command, *args, **kwargs):
+ cmdcopy = command
+ if isinstance(cmdcopy, bytes):
+ cmdcopy = cmdcopy.decode(conn.info.encoding)
+ elif isinstance(cmdcopy, sql.Composable):
+ cmdcopy = cmdcopy.as_string(conn)
+
+ L.append(cmdcopy)
+ return _orig_exec_command(command, *args, **kwargs)
+
+ monkeypatch.setattr(conn, "_exec_command", _exec_command)
+ return L
+
+
+class ListPopAll(list): # type: ignore[type-arg]
+ """A list, with a popall() method."""
+
+ def popall(self):
+ out = self[:]
+ del self[:]
+ return out
+
+
+def check_connection_version(node):
+ try:
+ pg_version
+ except NameError:
+ # First connection creation failed. Let the tests fail.
+ pytest.fail("server version not available")
+
+ for mark in node.iter_markers():
+ if mark.name == "pg":
+ assert len(mark.args) == 1
+ msg = check_postgres_version(pg_version, mark.args[0])
+ if msg:
+ pytest.skip(msg)
+
+ elif mark.name in ("crdb", "crdb_skip"):
+ from .fix_crdb import check_crdb_version
+
+ msg = check_crdb_version(crdb_version, mark)
+ if msg:
+ pytest.skip(msg)
+
+
+@pytest.fixture
+def hstore(svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists hstore")
+ except psycopg.Error as e:
+ pytest.skip(str(e))
+
+
+@cache
+def warm_up_database(dsn: str) -> None:
+ """Connect to the database before returning a connection.
+
+ In the CI sometimes, the first test fails with a timeout, probably because
+ the server hasn't started yet. Absorb the delay before the test.
+
+ In case of error, abort the test run entirely, to avoid failing downstream
+ hundreds of times.
+ """
+ global pg_version, crdb_version
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ conn.execute("select 1")
+
+ pg_version = conn.info.server_version
+
+ crdb_version = None
+ param = conn.info.parameter_status("crdb_version")
+ if param:
+ from psycopg.crdb import CrdbConnectionInfo
+
+ crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
+ except Exception as exc:
+ pytest.exit(f"failed to connect to the test database: {exc}")