summaryrefslogtreecommitdiffstats
path: root/tests/fix_psycopg.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
commit506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch)
tree808913770c5e6935d3714058c2a066c57b4632ec /tests/fix_psycopg.py
parentInitial commit. (diff)
downloadpsycopg3-upstream.tar.xz
psycopg3-upstream.zip
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/fix_psycopg.py')
-rw-r--r--tests/fix_psycopg.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py
new file mode 100644
index 0000000..80e0c62
--- /dev/null
+++ b/tests/fix_psycopg.py
@@ -0,0 +1,98 @@
+from copy import deepcopy
+
+import pytest
+
+
+@pytest.fixture
+def global_adapters():
+ """Restore the global adapters after a test has changed them."""
+ from psycopg import adapters
+
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+@pytest.fixture
+@pytest.mark.crdb_skip("2-phase commit")
+def tpc(svcconn):
+ tpc = Tpc(svcconn)
+ tpc.check_tpc()
+ tpc.clear_test_xacts()
+ tpc.make_test_table()
+ yield tpc
+ tpc.clear_test_xacts()
+
+
+class Tpc:
+ """Helper object to test two-phase transactions"""
+
+ def __init__(self, conn):
+ assert conn.autocommit
+ self.conn = conn
+
+ def check_tpc(self):
+ from .fix_crdb import is_crdb, crdb_skip_message
+
+ if is_crdb(self.conn):
+ pytest.skip(crdb_skip_message("2-phase commit"))
+
+ val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0])
+ if not val:
+ pytest.skip("prepared transactions disabled in the database")
+
+ def clear_test_xacts(self):
+ """Rollback all the prepared transaction in the testing db."""
+ from psycopg import sql
+
+ cur = self.conn.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (self.conn.info.dbname,),
+ )
+ gids = [r[0] for r in cur]
+ for gid in gids:
+ self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+ def make_test_table(self):
+ self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+ self.conn.execute("TRUNCATE test_tpc")
+
+ def count_xacts(self):
+ """Return the number of prepared xacts currently in the test db."""
+ cur = self.conn.execute(
+ """
+ select count(*) from pg_prepared_xacts
+ where database = %s""",
+ (self.conn.info.dbname,),
+ )
+ return cur.fetchone()[0]
+
+ def count_test_records(self):
+ """Return the number of records in the test table."""
+ cur = self.conn.execute("select count(*) from test_tpc")
+ return cur.fetchone()[0]
+
+
+@pytest.fixture(scope="module")
+def generators():
+ """Return the 'generators' module for selected psycopg implementation."""
+ from psycopg import pq
+
+ if pq.__impl__ == "c":
+ from psycopg._cmodule import _psycopg
+
+ return _psycopg
+ else:
+ import psycopg.generators
+
+ return psycopg.generators