diff options
Diffstat (limited to 'tests/pool')
-rw-r--r-- | tests/pool/__init__.py | 0 | ||||
-rw-r--r-- | tests/pool/fix_pool.py | 12 | ||||
-rw-r--r-- | tests/pool/test_null_pool.py | 896 | ||||
-rw-r--r-- | tests/pool/test_null_pool_async.py | 844 | ||||
-rw-r--r-- | tests/pool/test_pool.py | 1265 | ||||
-rw-r--r-- | tests/pool/test_pool_async.py | 1198 | ||||
-rw-r--r-- | tests/pool/test_pool_async_noasyncio.py | 78 | ||||
-rw-r--r-- | tests/pool/test_sched.py | 154 | ||||
-rw-r--r-- | tests/pool/test_sched_async.py | 159 |
9 files changed, 4606 insertions, 0 deletions
diff --git a/tests/pool/__init__.py b/tests/pool/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/pool/__init__.py diff --git a/tests/pool/fix_pool.py b/tests/pool/fix_pool.py new file mode 100644 index 0000000..12e4f39 --- /dev/null +++ b/tests/pool/fix_pool.py @@ -0,0 +1,12 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "pool: test related to the psycopg_pool package") + + +def pytest_collection_modifyitems(items): + # Add the pool markers to all the tests in the pool package + for item in items: + if "/pool/" in item.nodeid: + item.add_marker(pytest.mark.pool) diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py new file mode 100644 index 0000000..c0e8060 --- /dev/null +++ b/tests/pool/test_null_pool.py @@ -0,0 +1,896 @@ +import logging +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus + +from .test_pool import delay_connection, ensure_waiting + +try: + from psycopg_pool import NullConnectionPool + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +def test_defaults(dsn): + with NullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +def test_min_size_max_size(dsn): + with NullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + NullConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with NullConnectionPool(dsn, connection_class=MyConn) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_no_pool_at_all(dsn): + with NullConnectionPool(dsn, max_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: + assert conn.info.backend_pid not in (pid1, pid2) + + +def test_context(dsn): + with NullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.1) + + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.4) + + +def test_wait_closed(dsn): + with NullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + p.wait(0.2) + + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with NullConnectionPool(dsn, configure=configure) as p: + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 3 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + pids = [] + + def worker(): + with p.connection() as conn: + assert resets == 1 + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + assert resets == 0 + conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + t.join() + p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +def test_no_queue_timeout(deaf_port): + with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p: + with pytest.raises(PoolTimeout): + with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with NullConnectionPool(dsn, max_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with NullConnectionPool(dsn, max_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with NullConnectionPool(dsn, max_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + assert not conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + t = Thread(target=worker, args=(p,)) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with NullConnectionPool(dsn) as p: + conn = conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with NullConnectionPool(dsn) as p1: + with NullConnectionPool(dsn) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = NullConnectionPool(dsn) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = NullConnectionPool(dsn) + + with p.connection() as conn: + pass + assert conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = NullConnectionPool(dsn, max_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = NullConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +def test_reopen(dsn): + p = NullConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_resize(dsn, min_size, max_size): + with NullConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.1) + + ts = [] + with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + for i in range(5): + ts.append(Thread(target=worker, args=(p,))) + ts[-1].start() + + for t in ts: + t.join() + + assert pids[0] == pids[1] != pids[4], pids + + +def test_check(dsn): + with NullConnectionPool(dsn) as p: + # No-op + p.check() + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with NullConnectionPool(dsn, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + with NullConnectionPool(dsn, max_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with NullConnectionPool(proxy.client_dsn, max_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py new file mode 100644 index 0000000..23a1a52 --- /dev/null +++ b/tests/pool/test_null_pool_async.py @@ -0,0 +1,844 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task +from .test_pool_async import delay_connection, ensure_waiting + +pytestmark = [pytest.mark.asyncio] + +try: + from psycopg_pool import AsyncNullConnectionPool # noqa: F401 + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +async def test_defaults(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +async def test_min_size_max_size(dsn): + async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + AsyncNullConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_no_pool_at_all(dsn): + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid not in (pid1, pid2) + + +async def test_context(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.1) + + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.4) + + +async def test_wait_closed(dsn): + async with AsyncNullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await p.wait(0.2) + + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 3 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + pids = [] + + async def worker(): + async with p.connection() as conn: + assert resets == 1 + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + await p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +async def test_no_queue_timeout(deaf_port): + async with AsyncNullConnectionPool( + kwargs={"host": "localhost", "port": deaf_port} + ) as p: + with pytest.raises(PoolTimeout): + async with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + t = create_task(worker()) + await ensure_waiting(p) + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = AsyncNullConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with AsyncNullConnectionPool(dsn) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with AsyncNullConnectionPool(dsn) as p1: + async with AsyncNullConnectionPool(dsn) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = AsyncNullConnectionPool(dsn) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = AsyncNullConnectionPool(dsn) + + async with p.connection() as conn: + pass + assert conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = AsyncNullConnectionPool(dsn, max_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed): + await p.getconn() + + with pytest.raises(PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = AsyncNullConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +async def test_reopen(dsn): + p = AsyncNullConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_resize(dsn, min_size, max_size): + async with AsyncNullConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_max_lifetime(dsn): + pids: List[int] = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + await asyncio.sleep(0.1) + + async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + ts = [create_task(worker()) for i in range(5)] + await asyncio.gather(*ts) + + assert pids[0] == pids[1] != pids[4], pids + + +async def test_check(dsn): + # no.op + async with AsyncNullConnectionPool(dsn) as p: + await p.check() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with AsyncNullConnectionPool(dsn, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + async with AsyncNullConnectionPool(dsn, max_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py new file mode 100644 index 0000000..30c790b --- /dev/null +++ b/tests/pool/test_pool.py @@ -0,0 +1,1265 @@ +import logging +import weakref +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + + +def test_package_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg_pool import __version__ +assert __version__ +""" + ) + assert not cp.stdout + + +def test_defaults(dsn): + with pool.ConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +def test_min_size_max_size(dsn, min_size, max_size): + with pool.ConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.ConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with pool.ConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_really_a_pool(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +def test_context(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +def test_connection_not_lost(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + def add_time(self, conn): + times.append(time() - t0) + add_orig(self, conn) + + add_orig = pool.ConnectionPool._add_to_pool + monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + with pool.ConnectionPool(dsn, min_size=5, num_workers=2) as p: + p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.3) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.5) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=2) as p: + p.wait(0.3) + p.wait(0.0001) # idempotent + + +def test_wait_closed(dsn): + with pool.ConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + p.wait(0.2) + + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + p.wait() + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + assert resets == 0 + conn.execute("set timezone to '+2:00'") + + p.wait() + assert resets == 1 + + with p.connection() as conn: + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + + p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + assert not conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p1: + with pool.ConnectionPool(dsn, min_size=1) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +def test_del_no_warning(dsn, recwarn): + p = pool.ConnectionPool(dsn, min_size=2) + with p.connection() as conn: + conn.execute("select 1") + + p.wait() + ref = weakref.ref(p) + del p + assert not ref() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + + with p.connection() as conn: + pass + assert not conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = pool.ConnectionPool(dsn, min_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = pool.ConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.3) + finally: + p.close() + + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.5) + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.3) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.5) + + +def test_reopen(dsn): + p = pool.ConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p: + p.wait(1.0) + results: List[Tuple[int, float]] = [] + + ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool import ShrinkPool + + results: List[Tuple[int, int]] = [] + + def run_hacked(self, pool): + n0 = pool._nconns + orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.1)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + with pool.ConnectionPool(proxy.client_dsn, min_size=1) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + sleep(1.0) + proxy.start() + p.wait() + + with p.connection() as conn: + conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + with pool.ConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + t0 = time() + sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + with p.connection() as conn: + conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. + proxy.stop() + + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + assert ev.wait(timeout=2) + + with pytest.raises(pool.PoolTimeout): + with p.connection(timeout=0.5) as conn: + pass + + ev.clear() + assert ev.wait(timeout=2) + + proxy.start() + + with p.connection(timeout=2) as conn: + conn.execute("select 1") + + p.wait(timeout=3.0) + assert len(p._pool) == p.min_size == 4 + + +@pytest.mark.slow +def test_refill_on_check(proxy): + proxy.start() + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + # The pool is full + p.wait(timeout=2) + + # Break all the connection + proxy.stop() + + # Checking the pool will empty it + p.check() + assert ev.wait(timeout=2) + assert len(p._pool) == 0 + + # Allow to connect again + proxy.start() + + # Make sure that check has refilled the pool + p.check() + p.wait(timeout=2) + assert len(p._pool) == 4 + + +@pytest.mark.slow +def test_uniform_use(dsn): + with pool.ConnectionPool(dsn, min_size=4) as p: + counts = Counter[int]() + for i in range(8): + with p.connection() as conn: + sleep(0.1) + counts[id(conn)] += 1 + + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +@pytest.mark.slow +@pytest.mark.timing +def test_resize(dsn): + def sampler(): + sleep(0.05) # ensure sampling happens after shrink check + while True: + sleep(0.2) + if p.closed: + break + size.append(len(p._pool)) + + def client(t): + with p.connection() as conn: + conn.execute("select pg_sleep(%s)", [t]) + + size: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p: + s = Thread(target=sampler) + s.start() + + sleep(0.3) + c = Thread(target=client, args=(0.4,)) + c.start() + + sleep(0.2) + p.resize(4) + assert p.min_size == 4 + assert p.max_size == 4 + + sleep(0.4) + p.resize(2) + assert p.min_size == 2 + assert p.max_size == 2 + + sleep(0.6) + + s.join() + assert size == [2, 1, 3, 4, 3, 2, 2] + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)]) +def test_bad_resize(dsn, min_size, max_size): + with pool.ConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +def test_jitter(): + rnds = [pool.ConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)] + assert 27 <= min(rnds) <= 28 + assert 35 < max(rnds) < 36 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + with pool.ConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: + sleep(0.1) + pids = [] + for i in range(5): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.2) + + assert pids[0] == pids[1] != pids[4], pids + + +@pytest.mark.crdb_skip("backend pid") +def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + with pool.ConnectionPool(dsn, min_size=4) as p: + p.wait(1.0) + with p.connection() as conn: + pid = conn.info.backend_pid + + p.wait(1.0) + pids = set(conn.info.backend_pid for conn in p._pool) + assert pid in pids + conn.close() + + assert len(caplog.records) == 0 + p.check() + assert len(caplog.records) == 1 + p.wait(1.0) + pids2 = set(conn.info.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + +def test_check_idle(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait(1.0) + p.check() + with p.connection() as conn: + assert conn.info.transaction_status == TransactionStatus.IDLE + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 2 + assert stats["pool_available"] == 2 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except pool.PoolTimeout: + pass + + with pool.ConnectionPool(dsn, min_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with pool.ConnectionPool(proxy.client_dsn, min_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 3 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 600 <= stats["connections_ms"] < 1200 + + proxy.stop() + p.check() + sleep(0.1) + stats = p.get_stats() + assert stats["connections_num"] > 3 + assert stats["connections_errors"] > 0 + assert stats["connections_lost"] == 3 + + +@pytest.mark.slow +def test_spike(dsn, monkeypatch): + # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/ + # documents/Welcome-To-The-Jungle.md + delay_connection(monkeypatch, 0.15) + + def worker(): + with p.connection(): + sleep(0.002) + + with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p: + p.wait() + + ts = [Thread(target=worker) for i in range(50)] + for t in ts: + t.start() + for t in ts: + t.join() + p.wait() + + assert len(p._pool) < 7 + + +def test_debug_deadlock(dsn): + # https://github.com/psycopg/psycopg/issues/230 + logger = logging.getLogger("psycopg") + handler = logging.StreamHandler() + old_level = logger.level + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + try: + with pool.ConnectionPool(dsn, min_size=4, open=True) as p: + try: + p.wait(timeout=2) + finally: + print(p.get_stats()) + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + + def connect_delay(*args, **kwargs): + t0 = time() + rv = connect_orig(*args, **kwargs) + t1 = time() + sleep(max(0, sec - (t1 - t0))) + return rv + + connect_orig = psycopg.Connection.connect + monkeypatch.setattr(psycopg.Connection, "connect", connect_delay) + + +def ensure_waiting(p, num=1): + """ + Wait until there are at least *num* clients waiting in the queue. + """ + while len(p._waiting) < num: + sleep(0) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py new file mode 100644 index 0000000..286a775 --- /dev/null +++ b/tests/pool/test_pool_async.py @@ -0,0 +1,1198 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task, Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio] + + +async def test_defaults(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +async def test_min_size_max_size(dsn, min_size, max_size): + async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.AsyncConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with pool.AsyncConnectionPool( + dsn, kwargs={"autocommit": True}, min_size=1 + ) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_really_a_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +async def test_context(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +async def test_connection_not_lost(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + async with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +async def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + async def add_time(self, conn): + times.append(time() - t0) + await add_orig(self, conn) + + add_orig = pool.AsyncConnectionPool._add_to_pool + monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + async with pool.AsyncConnectionPool(dsn, min_size=5, num_workers=2) as p: + await p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.5) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=2) as p: + await p.wait(0.3) + await p.wait(0.0001) # idempotent + + +async def test_wait_closed(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await p.wait(0.2) + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + await p.wait(timeout=1.0) + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + + await p.wait() + assert resets == 1 + + async with p.connection() as conn: + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + + await p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + cur = await conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p1: + async with pool.AsyncConnectionPool(dsn, min_size=1) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + + async with p.connection() as conn: + pass + assert not conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = pool.AsyncConnectionPool(dsn, min_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed): + await p.getconn() + + with pytest.raises(pool.PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = pool.AsyncConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.3) + finally: + await p.close() + + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.5) + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.5) + + +async def test_reopen(dsn): + p = pool.AsyncConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +async def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + async with pool.AsyncConnectionPool( + dsn, min_size=min_size, max_size=4, num_workers=3 + ) as p: + await p.wait(1.0) + ts = [] + results: List[Tuple[int, float]] = [] + + ts = [create_task(worker(i)) for i in range(len(want_times))] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool_async import ShrinkPool + + results: List[Tuple[int, int]] = [] + + async def run_hacked(self, pool): + n0 = pool._nconns + await orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.1)") + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + await p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + await asyncio.sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +async def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=1) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + await asyncio.sleep(1.0) + proxy.start() + await p.wait() + + async with p.connection() as conn: + await conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + t0 = time() + await asyncio.sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +async def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. + proxy.stop() + + ev = asyncio.Event() + + def failed(pool): + ev.set() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + await asyncio.wait_for(ev.wait(), 2.0) + + with pytest.raises(pool.PoolTimeout): + async with p.connection(timeout=0.5) as conn: + pass + + ev.clear() + await asyncio.wait_for(ev.wait(), 2.0) + + proxy.start() + + async with p.connection(timeout=2) as conn: + await conn.execute("select 1") + + await p.wait(timeout=3.0) + assert len(p._pool) == p.min_size == 4 + + +@pytest.mark.slow +async def test_refill_on_check(proxy): + proxy.start() + ev = asyncio.Event() + + def failed(pool): + ev.set() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + # The pool is full + await p.wait(timeout=2) + + # Break all the connection + proxy.stop() + + # Checking the pool will empty it + await p.check() + await asyncio.wait_for(ev.wait(), 2.0) + assert len(p._pool) == 0 + + # Allow to connect again + proxy.start() + + # Make sure that check has refilled the pool + await p.check() + await p.wait(timeout=2) + assert len(p._pool) == 4 + + +@pytest.mark.slow +async def test_uniform_use(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=4) as p: + counts = Counter[int]() + for i in range(8): + async with p.connection() as conn: + await asyncio.sleep(0.1) + counts[id(conn)] += 1 + + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_resize(dsn): + async def sampler(): + await asyncio.sleep(0.05) # ensure sampling happens after shrink check + while True: + await asyncio.sleep(0.2) + if p.closed: + break + size.append(len(p._pool)) + + async def client(t): + async with p.connection() as conn: + await conn.execute("select pg_sleep(%s)", [t]) + + size: List[int] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p: + s = create_task(sampler()) + + await asyncio.sleep(0.3) + + c = create_task(client(0.4)) + + await asyncio.sleep(0.2) + await p.resize(4) + assert p.min_size == 4 + assert p.max_size == 4 + + await asyncio.sleep(0.4) + await p.resize(2) + assert p.min_size == 2 + assert p.max_size == 2 + + await asyncio.sleep(0.6) + + await asyncio.gather(s, c) + assert size == [2, 1, 3, 4, 3, 2, 2] + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)]) +async def test_bad_resize(dsn, min_size, max_size): + async with pool.AsyncConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +async def test_jitter(): + rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)] + assert 27 <= min(rnds) <= 28 + assert 35 < max(rnds) < 36 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_max_lifetime(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: + await asyncio.sleep(0.1) + pids = [] + for i in range(5): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + await asyncio.sleep(0.2) + + assert pids[0] == pids[1] != pids[4], pids + + +@pytest.mark.crdb_skip("backend pid") +async def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + async with pool.AsyncConnectionPool(dsn, min_size=4) as p: + await p.wait(1.0) + async with p.connection() as conn: + pid = conn.info.backend_pid + + await p.wait(1.0) + pids = set(conn.info.backend_pid for conn in p._pool) + assert pid in pids + await conn.close() + + assert len(caplog.records) == 0 + await p.check() + assert len(caplog.records) == 1 + await p.wait(1.0) + pids2 = set(conn.info.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + +async def test_check_idle(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + await p.wait(1.0) + await p.check() + async with p.connection() as conn: + assert conn.info.transaction_status == TransactionStatus.IDLE + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 2 + assert stats["pool_available"] == 2 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except pool.PoolTimeout: + pass + + async with pool.AsyncConnectionPool(dsn, min_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 3 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 580 <= stats["connections_ms"] < 1200 + + proxy.stop() + await p.check() + await asyncio.sleep(0.1) + stats = p.get_stats() + assert stats["connections_num"] > 3 + assert stats["connections_errors"] > 0 + assert stats["connections_lost"] == 3 + + +@pytest.mark.slow +async def test_spike(dsn, monkeypatch): + # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/ + # documents/Welcome-To-The-Jungle.md + delay_connection(monkeypatch, 0.15) + + async def worker(): + async with p.connection(): + await asyncio.sleep(0.002) + + async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p: + await p.wait() + + ts = [create_task(worker()) for i in range(50)] + await asyncio.gather(*ts) + await p.wait() + + assert len(p._pool) < 7 + + +async def test_debug_deadlock(dsn): + # https://github.com/psycopg/psycopg/issues/230 + logger = logging.getLogger("psycopg") + handler = logging.StreamHandler() + old_level = logger.level + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + try: + async with pool.AsyncConnectionPool(dsn, min_size=4, open=True) as p: + await p.wait(timeout=2) + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + + async def connect_delay(*args, **kwargs): + t0 = time() + rv = await connect_orig(*args, **kwargs) + t1 = time() + await asyncio.sleep(max(0, sec - (t1 - t0))) + return rv + + connect_orig = psycopg.AsyncConnection.connect + monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay) + + +async def ensure_waiting(p, num=1): + while len(p._waiting) < num: + await asyncio.sleep(0) diff --git a/tests/pool/test_pool_async_noasyncio.py b/tests/pool/test_pool_async_noasyncio.py new file mode 100644 index 0000000..f6e34e4 --- /dev/null +++ b/tests/pool/test_pool_async_noasyncio.py @@ -0,0 +1,78 @@ +# These tests relate to AsyncConnectionPool, but are not marked asyncio +# because they rely on the pool initialization outside the asyncio loop. + +import asyncio + +import pytest + +from ..utils import gc_collect + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + + +@pytest.mark.slow +def test_reconnect_after_max_lifetime(dsn, asyncio_run): + # See issue #219, pool created before the loop. + p = pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2, open=False) + + async def test(): + try: + await p.open() + ns = [] + for i in range(5): + async with p.connection() as conn: + cur = await conn.execute("select 1") + ns.append(await cur.fetchone()) + await asyncio.sleep(0.2) + assert len(ns) == 5 + finally: + await p.close() + + asyncio_run(asyncio.wait_for(test(), timeout=2.0)) + + +@pytest.mark.slow +def test_working_created_before_loop(dsn, asyncio_run): + p = pool.AsyncNullConnectionPool(dsn, open=False) + + async def test(): + try: + await p.open() + ns = [] + for i in range(5): + async with p.connection() as conn: + cur = await conn.execute("select 1") + ns.append(await cur.fetchone()) + await asyncio.sleep(0.2) + assert len(ns) == 5 + finally: + await p.close() + + asyncio_run(asyncio.wait_for(test(), timeout=2.0)) + + +def test_cant_create_open_outside_loop(dsn): + with pytest.raises(RuntimeError): + pool.AsyncConnectionPool(dsn, open=True) + + +@pytest.fixture +def asyncio_run(recwarn): + """Fixture reuturning asyncio.run, but managing resources at exit. + + In certain runs, fd objects are leaked and the error will only be caught + downstream, by some innocent test calling gc_collect(). + """ + recwarn.clear() + try: + yield asyncio.run + finally: + gc_collect() + if recwarn: + warn = recwarn.pop(ResourceWarning) + assert "unclosed event loop" in str(warn.message) + assert not recwarn diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py new file mode 100644 index 0000000..b3d2572 --- /dev/null +++ b/tests/pool/test_sched.py @@ -0,0 +1,154 @@ +import logging +from time import time, sleep +from functools import partial +from threading import Thread + +import pytest + +try: + from psycopg_pool.sched import Scheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.timing] + + +@pytest.mark.slow +def test_sched(): + s = Scheduler() + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +def test_sched_thread(): + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + def error(): + 1 / 0 + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, None) + s.enter(0.3, partial(worker, 2)) + s.enter(0.2, error) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +def test_empty_queue_timeout(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = Thread(target=s.run) + t.start() + sleep(0.5) + s.enter(0.5, None) + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +def test_first_task_rescheduling(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + s.enter(0.4, lambda: None) + t = Thread(target=s.run) + t.start() + s.enter(0.6, None) # this task doesn't trigger a reschedule + sleep(0.1) + s.enter(0.1, lambda: None) # this triggers a reschedule + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.2), times diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py new file mode 100644 index 0000000..492d620 --- /dev/null +++ b/tests/pool/test_sched_async.py @@ -0,0 +1,159 @@ +import asyncio +import logging +from time import time +from functools import partial + +import pytest + +from psycopg._compat import create_task + +try: + from psycopg_pool.sched import AsyncScheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio, pytest.mark.timing] + + +@pytest.mark.slow +async def test_sched(): + s = AsyncScheduler() + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + await s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +async def test_sched_task(): + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +async def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + async def error(): + 1 / 0 + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, None) + await s.enter(0.3, partial(worker, 2)) + await s.enter(0.2, error) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +async def test_empty_queue_timeout(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = create_task(s.run()) + await asyncio.sleep(0.5) + await s.enter(0.5, None) + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +async def test_first_task_rescheduling(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + async def noop(): + pass + + await s.enter(0.4, noop) + t = create_task(s.run()) + await s.enter(0.6, None) # this task doesn't trigger a reschedule + await asyncio.sleep(0.1) + await s.enter(0.1, noop) # this triggers a reschedule + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.2), times |