summaryrefslogtreecommitdiffstats
path: root/tests/test_concurrency_async.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_concurrency_async.py')
-rw-r--r--tests/test_concurrency_async.py242
1 files changed, 242 insertions, 0 deletions
diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py
new file mode 100644
index 0000000..29b08cf
--- /dev/null
+++ b/tests/test_concurrency_async.py
@@ -0,0 +1,242 @@
+import sys
+import time
+import signal
+import asyncio
+import subprocess as sp
+from asyncio.queues import Queue
+from typing import List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg import errors as e
+from psycopg._compat import create_task
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.slow
+async def test_commit_concurrency(aconn):
+ # Check the condition reported in psycopg2#103
+ # Because of bad status check, we commit even when a commit is already on
+ # its way. We can detect this condition by the warnings.
+ notices = Queue() # type: ignore[var-annotated]
+ aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary))
+ stop = False
+
+ async def committer():
+ nonlocal stop
+ while not stop:
+ await aconn.commit()
+ await asyncio.sleep(0) # Allow the other worker to work
+
+ async def runner():
+ nonlocal stop
+ cur = aconn.cursor()
+ for i in range(1000):
+ await cur.execute("select %s;", (i,))
+ await aconn.commit()
+
+ # Stop the committer thread
+ stop = True
+
+ await asyncio.gather(committer(), runner())
+ assert notices.empty(), "%d notices raised" % notices.qsize()
+
+
+@pytest.mark.slow
+async def test_concurrent_execution(aconn_cls, dsn):
+ async def worker():
+ cnn = await aconn_cls.connect(dsn)
+ cur = cnn.cursor()
+ await cur.execute("select pg_sleep(0.5)")
+ await cur.close()
+ await cnn.close()
+
+ workers = [worker(), worker()]
+ t0 = time.time()
+ await asyncio.gather(*workers)
+ assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("notify")
+async def test_notifies(aconn_cls, aconn, dsn):
+ nconn = await aconn_cls.connect(dsn, autocommit=True)
+ npid = nconn.pgconn.backend_pid
+
+ async def notifier():
+ cur = nconn.cursor()
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '1'")
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '2'")
+ await nconn.close()
+
+ async def receiver():
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ gen = aconn.notifies()
+ async for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ await gen.aclose()
+
+ ns: List[Tuple[psycopg.Notify, float]] = []
+ t0 = time.time()
+ workers = [notifier(), receiver()]
+ await asyncio.gather(*workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+async def canceller(aconn, errors):
+ try:
+ await asyncio.sleep(0.5)
+ aconn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ await cur.execute("select pg_sleep(2)")
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel_stream(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ async for row in cur.stream("select pg_sleep(2)"):
+ pass
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_identify_closure(aconn_cls, dsn):
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute(
+ "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid]
+ )
+
+ aconn = await aconn_cls.connect(dsn)
+ conn2 = await aconn_cls.connect(dsn)
+ try:
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("select pg_sleep(1.0)")
+ t1 = time.time()
+ assert 0.2 < t1 - t0 < 0.4
+ finally:
+ await asyncio.gather(t)
+ finally:
+ await aconn.close()
+ await conn2.close()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb_skip("cancel")
+async def test_ctrl_c(dsn):
+ script = f"""\
+import signal
+import asyncio
+import psycopg
+
+async def main():
+ ctrl_c = False
+ loop = asyncio.get_event_loop()
+ async with await psycopg.AsyncConnection.connect({dsn!r}) as conn:
+ loop.add_signal_handler(signal.SIGINT, conn.cancel)
+ cur = conn.cursor()
+ try:
+ await cur.execute("select pg_sleep(2)")
+ except psycopg.errors.QueryCanceled:
+ ctrl_c = True
+
+ assert ctrl_c, "ctrl-c not received"
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await conn.rollback()
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+asyncio.run(main())
+"""
+ if sys.platform == "win32":
+ creationflags = sp.CREATE_NEW_PROCESS_GROUP
+ sig = signal.CTRL_C_EVENT
+ else:
+ creationflags = 0
+ sig = signal.SIGINT
+
+ proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags)
+ with pytest.raises(sp.TimeoutExpired):
+ outs, errs = proc.communicate(timeout=1)
+
+ proc.send_signal(sig)
+ proc.communicate()
+ assert proc.returncode == 0