summaryrefslogtreecommitdiffstats
path: root/tests/test_tpc_async.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
commit506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch)
tree808913770c5e6935d3714058c2a066c57b4632ec /tests/test_tpc_async.py
parentInitial commit. (diff)
downloadpsycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz
psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_tpc_async.py')
-rw-r--r--tests/test_tpc_async.py310
1 files changed, 310 insertions, 0 deletions
diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py
new file mode 100644
index 0000000..a409a2e
--- /dev/null
+++ b/tests/test_tpc_async.py
@@ -0,0 +1,310 @@
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("2-phase commit"),
+]
+
+
+async def test_tpc_disabled(aconn, apipeline):
+ cur = await aconn.execute("show max_prepared_transactions")
+ val = int((await cur.fetchone())[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ await aconn.rollback()
+ await aconn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ await aconn.tpc_prepare()
+
+
+class TestTPC:
+ async def test_tpc_commit(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_commit(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_rollback(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_rollback(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_status_after_recover(self, aconn, tpc):
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ async def test_recovered_xids(self, aconn, tpc):
+ # insert a few test xns
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("begin; prepare transaction '1-foo'")
+ await cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ await cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (aconn.info.dbname,),
+ )
+ okvals = await cur.fetchall()
+ okvals.sort()
+
+ xids = await aconn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == aconn.info.dbname]
+ xids.sort(key=lambda x: x.gtrid)
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ async def test_xid_encoding(self, aconn, tpc):
+ xid = aconn.xid(42, "gtrid", "bqual")
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+
+ cur = aconn.cursor()
+ await cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (aconn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual):
+ xid = aconn.xid(fid, gtrid, bqual)
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid):
+ await aconn.tpc_begin(tid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc):
+ x1 = aconn.xid(10, "uni", "code")
+ await aconn.tpc_begin(x1)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ await aconn.execute("set client_encoding to utf8")
+ await aconn.commit()
+
+ await aconn.tpc_begin("transaction-id")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ async def test_cancel_fails_prepared(self, aconn, tpc):
+ await aconn.tpc_begin("cancel")
+ await aconn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ aconn.cancel()
+
+ async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc):
+ aconn.row_factory = psycopg.rows.dict_row
+ await aconn.tpc_begin("dict-connection")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = await aconn.tpc_recover()
+ xid = [x for x in xids if x.database == aconn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None