summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/README.rst94
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/adapters_example.py45
-rw-r--r--tests/conftest.py92
-rw-r--r--tests/constraints.txt32
-rw-r--r--tests/crdb/__init__.py0
-rw-r--r--tests/crdb/test_adapt.py78
-rw-r--r--tests/crdb/test_connection.py86
-rw-r--r--tests/crdb/test_connection_async.py85
-rw-r--r--tests/crdb/test_conninfo.py21
-rw-r--r--tests/crdb/test_copy.py233
-rw-r--r--tests/crdb/test_copy_async.py235
-rw-r--r--tests/crdb/test_cursor.py65
-rw-r--r--tests/crdb/test_cursor_async.py61
-rw-r--r--tests/crdb/test_no_crdb.py34
-rw-r--r--tests/crdb/test_typing.py49
-rw-r--r--tests/dbapi20.py870
-rw-r--r--tests/dbapi20_tpc.py151
-rw-r--r--tests/fix_crdb.py131
-rw-r--r--tests/fix_db.py358
-rw-r--r--tests/fix_faker.py868
-rw-r--r--tests/fix_mypy.py54
-rw-r--r--tests/fix_pq.py141
-rw-r--r--tests/fix_proxy.py127
-rw-r--r--tests/fix_psycopg.py98
-rw-r--r--tests/pool/__init__.py0
-rw-r--r--tests/pool/fix_pool.py12
-rw-r--r--tests/pool/test_null_pool.py896
-rw-r--r--tests/pool/test_null_pool_async.py844
-rw-r--r--tests/pool/test_pool.py1265
-rw-r--r--tests/pool/test_pool_async.py1198
-rw-r--r--tests/pool/test_pool_async_noasyncio.py78
-rw-r--r--tests/pool/test_sched.py154
-rw-r--r--tests/pool/test_sched_async.py159
-rw-r--r--tests/pq/__init__.py0
-rw-r--r--tests/pq/test_async.py210
-rw-r--r--tests/pq/test_conninfo.py48
-rw-r--r--tests/pq/test_copy.py174
-rw-r--r--tests/pq/test_escaping.py188
-rw-r--r--tests/pq/test_exec.py146
-rw-r--r--tests/pq/test_misc.py83
-rw-r--r--tests/pq/test_pgconn.py585
-rw-r--r--tests/pq/test_pgresult.py207
-rw-r--r--tests/pq/test_pipeline.py161
-rw-r--r--tests/pq/test_pq.py57
-rw-r--r--tests/scripts/bench-411.py300
-rw-r--r--tests/scripts/dectest.py51
-rw-r--r--tests/scripts/pipeline-demo.py340
-rw-r--r--tests/scripts/spiketest.py156
-rw-r--r--tests/test_adapt.py530
-rw-r--r--tests/test_client_cursor.py855
-rw-r--r--tests/test_client_cursor_async.py727
-rw-r--r--tests/test_concurrency.py327
-rw-r--r--tests/test_concurrency_async.py242
-rw-r--r--tests/test_connection.py790
-rw-r--r--tests/test_connection_async.py751
-rw-r--r--tests/test_conninfo.py450
-rw-r--r--tests/test_copy.py889
-rw-r--r--tests/test_copy_async.py892
-rw-r--r--tests/test_cursor.py942
-rw-r--r--tests/test_cursor_async.py802
-rw-r--r--tests/test_dns.py27
-rw-r--r--tests/test_dns_srv.py149
-rw-r--r--tests/test_encodings.py57
-rw-r--r--tests/test_errors.py309
-rw-r--r--tests/test_generators.py156
-rw-r--r--tests/test_module.py57
-rw-r--r--tests/test_pipeline.py577
-rw-r--r--tests/test_pipeline_async.py586
-rw-r--r--tests/test_prepared.py277
-rw-r--r--tests/test_prepared_async.py207
-rw-r--r--tests/test_psycopg_dbapi20.py164
-rw-r--r--tests/test_query.py162
-rw-r--r--tests/test_rows.py167
-rw-r--r--tests/test_server_cursor.py525
-rw-r--r--tests/test_server_cursor_async.py543
-rw-r--r--tests/test_sql.py604
-rw-r--r--tests/test_tpc.py325
-rw-r--r--tests/test_tpc_async.py310
-rw-r--r--tests/test_transaction.py796
-rw-r--r--tests/test_transaction_async.py743
-rw-r--r--tests/test_typeinfo.py145
-rw-r--r--tests/test_typing.py449
-rw-r--r--tests/test_waiting.py159
-rw-r--r--tests/test_windows.py23
-rw-r--r--tests/types/__init__.py0
-rw-r--r--tests/types/test_array.py338
-rw-r--r--tests/types/test_bool.py47
-rw-r--r--tests/types/test_composite.py396
-rw-r--r--tests/types/test_datetime.py813
-rw-r--r--tests/types/test_enum.py363
-rw-r--r--tests/types/test_hstore.py107
-rw-r--r--tests/types/test_json.py182
-rw-r--r--tests/types/test_multirange.py434
-rw-r--r--tests/types/test_net.py135
-rw-r--r--tests/types/test_none.py12
-rw-r--r--tests/types/test_numeric.py625
-rw-r--r--tests/types/test_range.py677
-rw-r--r--tests/types/test_shapely.py152
-rw-r--r--tests/types/test_string.py307
-rw-r--r--tests/types/test_uuid.py56
-rw-r--r--tests/typing_example.py176
-rw-r--r--tests/utils.py179
103 files changed, 32033 insertions, 0 deletions
diff --git a/tests/README.rst b/tests/README.rst
new file mode 100644
index 0000000..63c7238
--- /dev/null
+++ b/tests/README.rst
@@ -0,0 +1,94 @@
+psycopg test suite
+===================
+
+Quick version
+-------------
+
+To run tests on the current code you can install the `test` extra of the
+package, specify a connection string in the `PSYCOPG_TEST_DSN` env var to
+connect to a test database, and run ``pytest``::
+
+ $ pip install -e "psycopg[test]"
+ $ export PSYCOPG_TEST_DSN="host=localhost dbname=psycopg_test"
+ $ pytest
+
+
+Test options
+------------
+
+- The tests output header shows additional psycopg related information,
+ on top of the one normally displayed by ``pytest`` and the extensions used::
+
+ $ pytest
+ ========================= test session starts =========================
+ platform linux -- Python 3.8.5, pytest-6.0.2, py-1.10.0, pluggy-0.13.1
+ Using --randomly-seed=2416596601
+ libpq available: 130002
+ libpq wrapper implementation: c
+
+
+- By default the tests run using the ``pq`` implementation that psycopg would
+ choose (the C module if installed, else the Python one). In order to test a
+ different implementation, use the normal `pq module selection mechanism`__
+ of the ``PSYCOPG_IMPL`` env var::
+
+ $ PSYCOPG_IMPL=python pytest
+ ========================= test session starts =========================
+ [...]
+ libpq available: 130002
+ libpq wrapper implementation: python
+
+ .. __: https://www.psycopg.org/psycopg/docs/api/pq.html#pq-module-implementations
+
+
+- Slow tests have a ``slow`` marker which can be selected to reduce test
+ runtime to a few seconds only. Please add a ``@pytest.mark.slow`` marker to
+ any test needing an arbitrary wait. At the time of writing::
+
+ $ pytest
+ ========================= test session starts =========================
+ [...]
+ ======= 1983 passed, 3 skipped, 110 xfailed in 78.21s (0:01:18) =======
+
+ $ pytest -m "not slow"
+ ========================= test session starts =========================
+ [...]
+ ==== 1877 passed, 2 skipped, 169 deselected, 48 xfailed in 13.47s =====
+
+- ``pytest`` option ``--pq-trace={TRACEFILE,STDERR}`` can be used to capture
+ libpq trace. When using ``stderr``, the output will only be shown for
+ failing or in-error tests, unless ``-s/--capture=no`` option is used.
+
+- ``pytest`` option ``--pq-debug`` can be used to log access to libpq's
+ ``PGconn`` functions.
+
+
+Testing in docker
+-----------------
+
+Useful to test different Python versions without installing them. Can be used
+to replicate GitHub actions failures, specifying the ``--randomly-seed`` used
+in the test run. The following ``PG*`` env vars are an example to adjust the
+test dsn in order to connect to a database running on the docker host: specify
+a set of env vars working for your setup::
+
+ $ docker run -ti --rm --volume `pwd`:/src --workdir /src \
+ -e PSYCOPG_TEST_DSN -e PGHOST=172.17.0.1 -e PGUSER=`whoami` \
+ python:3.7 bash
+
+ # pip install -e "./psycopg[test]" ./psycopg_pool ./psycopg_c
+ # pytest
+
+
+Testing with CockroachDB
+========================
+
+You can run CRDB in a docker container using::
+
+ docker run -p 26257:26257 --name crdb --rm \
+ cockroachdb/cockroach:v22.1.3 start-single-node --insecure
+
+And use the following connection string to run the tests::
+
+ export PSYCOPG_TEST_DSN="host=localhost port=26257 user=root dbname=defaultdb"
+ pytest ...
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/adapters_example.py b/tests/adapters_example.py
new file mode 100644
index 0000000..a184e6a
--- /dev/null
+++ b/tests/adapters_example.py
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from psycopg import pq
+from psycopg.abc import Dumper, Loader, AdaptContext, PyFormat, Buffer
+
+
+def f() -> None:
+ d: Dumper = MyStrDumper(str, None)
+ assert d.dump("abc") == b"abcabc"
+ assert d.quote("abc") == b"'abcabc'"
+
+ lo: Loader = MyTextLoader(0, None)
+ assert lo.load(b"abc") == "abcabc"
+
+
+class MyStrDumper:
+ format = pq.Format.TEXT
+ oid = 25 # text
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ self._cls = cls
+
+ def dump(self, obj: str) -> bytes:
+ return (obj * 2).encode()
+
+ def quote(self, obj: str) -> bytes:
+ value = self.dump(obj)
+ esc = pq.Escaping()
+ return b"'%s'" % esc.escape_string(value.replace(b"h", b"q"))
+
+ def get_key(self, obj: str, format: PyFormat) -> type:
+ return self._cls
+
+ def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper":
+ return self
+
+
+class MyTextLoader:
+ format = pq.Format.TEXT
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ pass
+
+ def load(self, data: Buffer) -> str:
+ return (bytes(data) * 2).decode()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..15bcf40
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,92 @@
+import sys
+import asyncio
+import selectors
+from typing import List
+
+pytest_plugins = (
+ "tests.fix_db",
+ "tests.fix_pq",
+ "tests.fix_mypy",
+ "tests.fix_faker",
+ "tests.fix_proxy",
+ "tests.fix_psycopg",
+ "tests.fix_crdb",
+ "tests.pool.fix_pool",
+)
+
+
+def pytest_configure(config):
+ markers = [
+ "slow: this test is kinda slow (skip with -m 'not slow')",
+ "flakey(reason): this test may fail unpredictably')",
+ # There are troubles on travis with these kind of tests and I cannot
+ # catch the exception for my life.
+ "subprocess: the test import psycopg after subprocess",
+ "timing: the test is timing based and can fail on cheese hardware",
+ "dns: the test requires dnspython to run",
+ "postgis: the test requires the PostGIS extension to run",
+ ]
+
+ for marker in markers:
+ config.addinivalue_line("markers", marker)
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--loop",
+ choices=["default", "uvloop"],
+ default="default",
+ help="The asyncio loop to use for async tests.",
+ )
+
+
+def pytest_report_header(config):
+ rv = []
+
+ rv.append(f"default selector: {selectors.DefaultSelector.__name__}")
+ loop = config.getoption("--loop")
+ if loop != "default":
+ rv.append(f"asyncio loop: {loop}")
+
+ return rv
+
+
+def pytest_sessionstart(session):
+ # Detect if there was a segfault in the previous run.
+ #
+ # In case of segfault, pytest doesn't get a chance to write failed tests
+ # in the cache. As a consequence, retries would find no test failed and
+ # assume that all tests passed in the previous run, making the whole test pass.
+ cache = session.config.cache
+ if cache.get("segfault", False):
+ session.warn(Warning("Previous run resulted in segfault! Not running any test"))
+ session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)"))
+ raise session.Failed
+ cache.set("segfault", True)
+
+ # Configure the async loop.
+ loop = session.config.getoption("--loop")
+ if loop == "uvloop":
+ import uvloop
+
+ uvloop.install()
+ else:
+ assert loop == "default"
+
+ if sys.platform == "win32":
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+
+allow_fail_messages: List[str] = []
+
+
+def pytest_sessionfinish(session, exitstatus):
+ # Mark the test run successful (in the sense -weak- that we didn't segfault).
+ session.config.cache.set("segfault", False)
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+ if allow_fail_messages:
+ terminalreporter.section("failed tests ignored")
+ for msg in allow_fail_messages:
+ terminalreporter.line(msg)
diff --git a/tests/constraints.txt b/tests/constraints.txt
new file mode 100644
index 0000000..ef03ba1
--- /dev/null
+++ b/tests/constraints.txt
@@ -0,0 +1,32 @@
+# This is a constraint file forcing the minimum allowed version to be
+# installed.
+#
+# https://pip.pypa.io/en/stable/user_guide/#constraints-files
+
+# From install_requires
+backports.zoneinfo == 0.2.0
+typing-extensions == 4.1.0
+
+# From the 'test' extra
+mypy == 0.981
+pproxy == 2.7.0
+pytest == 6.2.5
+pytest-asyncio == 0.17.0
+pytest-cov == 3.0.0
+pytest-randomly == 3.10.0
+
+# From the 'dev' extra
+black == 22.3.0
+dnspython == 2.1.0
+flake8 == 4.0.0
+mypy == 0.981
+types-setuptools == 57.4.0
+wheel == 0.37
+
+# From the 'docs' extra
+Sphinx == 4.2.0
+furo == 2021.11.23
+sphinx-autobuild == 2021.3.14
+sphinx-autodoc-typehints == 1.12.0
+dnspython == 2.1.0
+shapely == 1.7.0
diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/crdb/__init__.py
diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py
new file mode 100644
index 0000000..ce5bacf
--- /dev/null
+++ b/tests/crdb/test_adapt.py
@@ -0,0 +1,78 @@
+from copy import deepcopy
+
+import pytest
+
+from psycopg.crdb import adapters, CrdbConnection
+
+from psycopg.adapt import PyFormat, Transformer
+from psycopg.types.array import ListDumper
+from psycopg.postgres import types as builtins
+
+from ..test_adapt import MyStr, make_dumper, make_bin_dumper
+from ..test_adapt import make_loader, make_bin_loader
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as text oid to CockroachDB, unlike Postgres,
+ # to which strings are passed as unknown. This is because CRDB doesn't
+ # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises
+ # an error. However, unlike PostgreSQL, text can be cast to any other type.
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+
+
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+@pytest.fixture
+def crdb_adapters():
+ """Restore the crdb adapters after a test has changed them."""
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+def test_dump_global_ctx(dsn, crdb_adapters, pgconn):
+ adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ adapters.register_dumper(MyStr, make_dumper("gt"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_load_global_ctx(dsn, crdb_adapters):
+ adapters.register_loader("text", make_loader("gt"))
+ adapters.register_loader("text", make_bin_loader("gb"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py
new file mode 100644
index 0000000..b2a69ef
--- /dev/null
+++ b/tests/crdb/test_connection.py
@@ -0,0 +1,86 @@
+import time
+import threading
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_is_crdb(conn):
+ assert CrdbConnection.is_crdb(conn)
+ assert CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_connect(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+ with psycopg.crdb.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+
+def test_xid(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+def test_tpc_begin(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_begin("foo")
+
+
+def test_tpc_recover(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_recover()
+
+
+@pytest.mark.slow
+def test_broken_connection(conn):
+ cur = conn.cursor()
+ (session_id,) = cur.execute("select session_id from [show session_id]").fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("cancel session %s", [session_id])
+ assert conn.closed
+
+
+@pytest.mark.slow
+def test_broken(conn):
+ (session_id,) = conn.execute("show session_id").fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("cancel session %s", [session_id])
+
+ assert conn.closed
+ assert conn.broken
+ conn.close()
+ assert conn.closed
+ assert conn.broken
+
+
+@pytest.mark.slow
+def test_identify_closure(conn_cls, dsn):
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ with conn_cls.connect(dsn, autocommit=True) as conn2:
+ (session_id,) = conn.execute("show session_id").fetchone()
+
+ def closer():
+ time.sleep(0.2)
+ conn2.execute("cancel session %s", [session_id])
+
+ t = threading.Thread(target=closer)
+ t.start()
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ # CRDB seems to take not less than 1s
+ assert 0.2 < dt < 2
+ finally:
+ t.join()
diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py
new file mode 100644
index 0000000..b568e42
--- /dev/null
+++ b/tests/crdb/test_connection_async.py
@@ -0,0 +1,85 @@
+import time
+import asyncio
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import AsyncCrdbConnection
+from psycopg._compat import create_task
+
+import pytest
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+async def test_is_crdb(aconn):
+ assert AsyncCrdbConnection.is_crdb(aconn)
+ assert AsyncCrdbConnection.is_crdb(aconn.pgconn)
+
+
+async def test_connect(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
+
+
+async def test_xid(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+async def test_tpc_begin(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_begin("foo")
+
+
+async def test_tpc_recover(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_recover()
+
+
+@pytest.mark.slow
+async def test_broken_connection(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select session_id from [show session_id]")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("cancel session %s", [session_id])
+ assert aconn.closed
+
+
+@pytest.mark.slow
+async def test_broken(aconn):
+ cur = await aconn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("cancel session %s", [session_id])
+
+ assert aconn.closed
+ assert aconn.broken
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.broken
+
+
+@pytest.mark.slow
+async def test_identify_closure(aconn_cls, dsn):
+ async with await aconn_cls.connect(dsn) as conn:
+ async with await aconn_cls.connect(dsn) as conn2:
+ cur = await conn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute("cancel session %s", [session_id])
+
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ assert 0.2 < dt < 2
+ finally:
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py
new file mode 100644
index 0000000..274a0c0
--- /dev/null
+++ b/tests/crdb/test_conninfo.py
@@ -0,0 +1,21 @@
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_vendor(conn):
+ assert conn.info.vendor == "CockroachDB"
+
+
+def test_server_version(conn):
+ assert conn.info.server_version > 200000
+
+
+@pytest.mark.crdb("< 22")
+def test_backend_pid_pre_22(conn):
+ assert conn.info.backend_pid == 0
+
+
+@pytest.mark.crdb(">= 22")
+def test_backend_pid(conn):
+ assert conn.info.backend_pid > 0
diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py
new file mode 100644
index 0000000..b7d26aa
--- /dev/null
+++ b/tests/crdb/test_copy.py
@@ -0,0 +1,233 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg import sql, errors as e
+from psycopg.pq import Format
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import ensure_table, sample_records
+from ..test_copy import sample_tabledef as sample_tabledef_pg
+
+# CRDB int/serial are int8
+sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin with binary") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+def test_copy_big_size_block(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select col2, data from copy_in order by 2").fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def copyopt(format):
+ return "with binary" if format == Format.BINARY else ""
diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py
new file mode 100644
index 0000000..d5fbf50
--- /dev/null
+++ b/tests/crdb/test_copy_async.py
@@ -0,0 +1,235 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg.pq import Format
+from psycopg import sql, errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import sample_records
+from ..test_copy_async import ensure_table
+from .test_copy import sample_tabledef, copyopt
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_str(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text.decode())
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+async def test_copy_in_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ async with cur.copy("copy copy_in from stdin with binary") as copy:
+ await copy.write(sample_text.decode())
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+async def test_copy_big_size_record(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write_row([data])
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+@pytest.mark.slow
+async def test_copy_big_size_block(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write(copy_data)
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_binary(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin {copyopt(format)}"
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select col2, data from copy_in order by 2")
+ data = await cur.fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ await copy.write_row(row)
+
+ await cur.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py
new file mode 100644
index 0000000..991b084
--- /dev/null
+++ b/tests/crdb/test_cursor.py
@@ -0,0 +1,65 @@
+import json
+import threading
+from uuid import uuid4
+from queue import Queue
+from typing import Any
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.fixture
+def testfeed(svcconn):
+ name = f"test_feed_{str(uuid4()).replace('-', '')}"
+ svcconn.execute("set cluster setting kv.rangefeed.enabled to true")
+ svcconn.execute(f"create table {name} (id serial primary key, data text)")
+ yield name
+ svcconn.execute(f"drop table {name}")
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
+ conn.autocommit = True
+ q: "Queue[Any]" = Queue()
+
+ def worker():
+ try:
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ for row in cur.stream(f"experimental changefeed for {testfeed}"):
+ q.put(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put(None)
+ except Exception as ex:
+ q.put(ex)
+
+ t = threading.Thread(target=worker)
+ t.start()
+
+ cur = conn.cursor()
+ cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = cur.fetchone()
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = cur.fetchone()
+ cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert q.get(timeout=1) is None
+ t.join()
diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py
new file mode 100644
index 0000000..229295d
--- /dev/null
+++ b/tests/crdb/test_cursor_async.py
@@ -0,0 +1,61 @@
+import json
+import asyncio
+from typing import Any
+from asyncio.queues import Queue
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+from psycopg._compat import create_task
+
+from .test_cursor import testfeed
+
+testfeed # fixture
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
+ await aconn.set_autocommit(True)
+ q: "Queue[Any]" = Queue()
+
+ async def worker():
+ try:
+ async with await aconn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ async for row in cur.stream(
+ f"experimental changefeed for {testfeed}"
+ ):
+ q.put_nowait(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put_nowait(None)
+ except Exception as ex:
+ q.put_nowait(ex)
+
+ t = create_task(worker())
+
+ cur = aconn.cursor()
+ await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = await cur.fetchone()
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ await cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ await cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = await cur.fetchone()
+ await cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert await asyncio.wait_for(q.get(), 1.0) is None
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py
new file mode 100644
index 0000000..df43f3b
--- /dev/null
+++ b/tests/crdb/test_no_crdb.py
@@ -0,0 +1,34 @@
+from psycopg.pq import TransactionStatus
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb("skip")
+
+
+def test_is_crdb(conn):
+ assert not CrdbConnection.is_crdb(conn)
+ assert not CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_tpc_on_pg_connection(conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py
new file mode 100644
index 0000000..2cff0a7
--- /dev/null
+++ b/tests/crdb/test_typing.py
@@ -0,0 +1,49 @@
+import pytest
+
+from ..test_typing import _test_reveal
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.crdb.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect()",
+ "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal_crdb(stmts, type, mypy)
+
+
+def _test_reveal_crdb(stmts, type, mypy):
+ stmts = f"""\
+import psycopg.crdb
+{stmts}
+"""
+ _test_reveal(stmts, type, mypy)
diff --git a/tests/dbapi20.py b/tests/dbapi20.py
new file mode 100644
index 0000000..c873a4e
--- /dev/null
+++ b/tests/dbapi20.py
@@ -0,0 +1,870 @@
+#!/usr/bin/env python
+# flake8: noqa
+# fmt: off
+''' Python DB API 2.0 driver compliance unit test suite.
+
+ This software is Public Domain and may be used without restrictions.
+
+ "Now we have booze and barflies entering the discussion, plus rumours of
+ DBAs on drugs... and I won't tell you what flashes through my mind each
+ time I read the subject line with 'Anal Compliance' in it. All around
+ this is turning out to be a thoroughly unwholesome unit test."
+
+ -- Ian Bicking
+'''
+
+__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
+__version__ = '$Revision: 1.12 $'[11:-2]
+__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
+
+import unittest
+import time
+import sys
+from typing import Any, Dict
+
+
+# Revision 1.12 2009/02/06 03:35:11 kf7xm
+# Tested okay with Python 3.0, includes last minute patches from Mark H.
+#
+# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole
+# Include latest changes from main branch
+# Updates for py3k
+#
+# Revision 1.11 2005/01/02 02:41:01 zenzen
+# Update author email address
+#
+# Revision 1.10 2003/10/09 03:14:14 zenzen
+# Add test for DB API 2.0 optional extension, where database exceptions
+# are exposed as attributes on the Connection object.
+#
+# Revision 1.9 2003/08/13 01:16:36 zenzen
+# Minor tweak from Stefan Fleiter
+#
+# Revision 1.8 2003/04/10 00:13:25 zenzen
+# Changes, as per suggestions by M.-A. Lemburg
+# - Add a table prefix, to ensure namespace collisions can always be avoided
+#
+# Revision 1.7 2003/02/26 23:33:37 zenzen
+# Break out DDL into helper functions, as per request by David Rushby
+#
+# Revision 1.6 2003/02/21 03:04:33 zenzen
+# Stuff from Henrik Ekelund:
+# added test_None
+# added test_nextset & hooks
+#
+# Revision 1.5 2003/02/17 22:08:43 zenzen
+# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
+# defaults to 1 & generic cursor.callproc test added
+#
+# Revision 1.4 2003/02/15 00:16:33 zenzen
+# Changes, as per suggestions and bug reports by M.-A. Lemburg,
+# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
+# - Class renamed
+# - Now a subclass of TestCase, to avoid requiring the driver stub
+# to use multiple inheritance
+# - Reversed the polarity of buggy test in test_description
+# - Test exception hierarchy correctly
+# - self.populate is now self._populate(), so if a driver stub
+# overrides self.ddl1 this change propagates
+# - VARCHAR columns now have a width, which will hopefully make the
+# DDL even more portible (this will be reversed if it causes more problems)
+# - cursor.rowcount being checked after various execute and fetchXXX methods
+# - Check for fetchall and fetchmany returning empty lists after results
+# are exhausted (already checking for empty lists if select retrieved
+# nothing
+# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
+#
+
+class DatabaseAPI20Test(unittest.TestCase):
+ ''' Test a database self.driver for DB API 2.0 compatibility.
+ This implementation tests Gadfly, but the TestCase
+ is structured so that other self.drivers can subclass this
+ test case to ensure compiliance with the DB-API. It is
+ expected that this TestCase may be expanded in the future
+ if ambiguities or edge conditions are discovered.
+
+ The 'Optional Extensions' are not yet being tested.
+
+ self.drivers should subclass this test, overriding setUp, tearDown,
+ self.driver, connect_args and connect_kw_args. Class specification
+ should be as follows:
+
+ from . import dbapi20
+ class mytest(dbapi20.DatabaseAPI20Test):
+ [...]
+
+ Don't 'from .dbapi20 import DatabaseAPI20Test', or you will
+ confuse the unit tester - just 'from . import dbapi20'.
+ '''
+
+ # The self.driver module. This should be the module where the 'connect'
+ # method is to be found
+ driver: Any = None
+ connect_args = () # List of arguments to pass to connect
+ connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect
+ table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+
+ ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
+ ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
+ xddl1 = 'drop table %sbooze' % table_prefix
+ xddl2 = 'drop table %sbarflys' % table_prefix
+
+ lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+
+ # Some drivers may need to override these helpers, for example adding
+ # a 'commit' after the execute.
+ def executeDDL1(self,cursor):
+ cursor.execute(self.ddl1)
+
+ def executeDDL2(self,cursor):
+ cursor.execute(self.ddl2)
+
+ def setUp(self):
+ ''' self.drivers should override this method to perform required setup
+ if any is necessary, such as creating the database.
+ '''
+ pass
+
+ def tearDown(self):
+ ''' self.drivers should override this method to perform required cleanup
+ if any is necessary, such as deleting the test database.
+ The default drops the tables that may be created.
+ '''
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ for ddl in (self.xddl1,self.xddl2):
+ try:
+ cur.execute(ddl)
+ con.commit()
+ except self.driver.Error:
+ # Assume table didn't exist. Other tests will check if
+ # execute is busted.
+ pass
+ finally:
+ con.close()
+
+ def _connect(self):
+ try:
+ return self.driver.connect(
+ *self.connect_args,**self.connect_kw_args
+ )
+ except AttributeError:
+ self.fail("No connect method found in self.driver module")
+
+ def test_connect(self):
+ con = self._connect()
+ con.close()
+
+ def test_apilevel(self):
+ try:
+ # Must exist
+ apilevel = self.driver.apilevel
+ # Must equal 2.0
+ self.assertEqual(apilevel,'2.0')
+ except AttributeError:
+ self.fail("Driver doesn't define apilevel")
+
+ def test_threadsafety(self):
+ try:
+ # Must exist
+ threadsafety = self.driver.threadsafety
+ # Must be a valid value
+ self.failUnless(threadsafety in (0,1,2,3))
+ except AttributeError:
+ self.fail("Driver doesn't define threadsafety")
+
+ def test_paramstyle(self):
+ try:
+ # Must exist
+ paramstyle = self.driver.paramstyle
+ # Must be a valid value
+ self.failUnless(paramstyle in (
+ 'qmark','numeric','named','format','pyformat'
+ ))
+ except AttributeError:
+ self.fail("Driver doesn't define paramstyle")
+
+ def test_Exceptions(self):
+ # Make sure required exceptions exist, and are in the
+ # defined hierarchy.
+ if sys.version[0] == '3': #under Python 3 StardardError no longer exists
+ self.failUnless(issubclass(self.driver.Warning,Exception))
+ self.failUnless(issubclass(self.driver.Error,Exception))
+ else:
+ self.failUnless(issubclass(self.driver.Warning,StandardError)) # type: ignore[name-defined]
+ self.failUnless(issubclass(self.driver.Error,StandardError)) # type: ignore[name-defined]
+
+ self.failUnless(
+ issubclass(self.driver.InterfaceError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.DatabaseError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.OperationalError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.IntegrityError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.InternalError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.ProgrammingError,self.driver.Error)
+ )
+ self.failUnless(
+ issubclass(self.driver.NotSupportedError,self.driver.Error)
+ )
+
+ def test_ExceptionsAsConnectionAttributes(self):
+ # OPTIONAL EXTENSION
+ # Test for the optional DB API 2.0 extension, where the exceptions
+ # are exposed as attributes on the Connection object
+ # I figure this optional extension will be implemented by any
+ # driver author who is using this test suite, so it is enabled
+ # by default.
+ con = self._connect()
+ drv = self.driver
+ self.failUnless(con.Warning is drv.Warning)
+ self.failUnless(con.Error is drv.Error)
+ self.failUnless(con.InterfaceError is drv.InterfaceError)
+ self.failUnless(con.DatabaseError is drv.DatabaseError)
+ self.failUnless(con.OperationalError is drv.OperationalError)
+ self.failUnless(con.IntegrityError is drv.IntegrityError)
+ self.failUnless(con.InternalError is drv.InternalError)
+ self.failUnless(con.ProgrammingError is drv.ProgrammingError)
+ self.failUnless(con.NotSupportedError is drv.NotSupportedError)
+ con.close()
+
+
+ def test_commit(self):
+ con = self._connect()
+ try:
+ # Commit must work, even if it doesn't do anything
+ con.commit()
+ finally:
+ con.close()
+
+ def test_rollback(self):
+ con = self._connect()
+ # If rollback is defined, it should either work or throw
+ # the documented exception
+ if hasattr(con,'rollback'):
+ try:
+ con.rollback()
+ except self.driver.NotSupportedError:
+ pass
+ con.close()
+
+ def test_cursor(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ def test_cursor_isolation(self):
+ con = self._connect()
+ try:
+ # Make sure cursors created from the same connection have
+ # the documented transaction isolation level
+ cur1 = con.cursor()
+ cur2 = con.cursor()
+ self.executeDDL1(cur1)
+ cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ cur2.execute("select name from %sbooze" % self.table_prefix)
+ booze = cur2.fetchall()
+ self.assertEqual(len(booze),1)
+ self.assertEqual(len(booze[0]),1)
+ self.assertEqual(booze[0][0],'Victoria Bitter')
+ finally:
+ con.close()
+
+ def test_description(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description should be none after executing a '
+ 'statement that can return no rows (such as DDL)'
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(len(cur.description),1,
+ 'cursor.description describes too many columns'
+ )
+ self.assertEqual(len(cur.description[0]),7,
+ 'cursor.description[x] tuples must have 7 elements'
+ )
+ self.assertEqual(cur.description[0][0].lower(),'name',
+ 'cursor.description[x][0] must return column name'
+ )
+ self.assertEqual(cur.description[0][1],self.driver.STRING,
+ 'cursor.description[x][1] must return column type. Got %r'
+ % cur.description[0][1]
+ )
+
+ # Make sure self.description gets reset
+ self.executeDDL2(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description not being set to None when executing '
+ 'no-result statements (eg. DDL)'
+ )
+ finally:
+ con.close()
+
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount should be -1 after executing no-result '
+ 'statements'
+ )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.failUnless(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number or rows inserted, or '
+ 'set to -1 after executing an insert statement'
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.failUnless(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount not being reset to -1 after executing '
+ 'no-result statements'
+ )
+ finally:
+ con.close()
+
+ lower_func = 'lower'
+ def test_callproc(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if self.lower_func and hasattr(cur,'callproc'):
+ r = cur.callproc(self.lower_func,('FOO',))
+ self.assertEqual(len(r),1)
+ self.assertEqual(r[0],'FOO')
+ r = cur.fetchall()
+ self.assertEqual(len(r),1,'callproc produced no result set')
+ self.assertEqual(len(r[0]),1,
+ 'callproc produced invalid result set'
+ )
+ self.assertEqual(r[0][0],'foo',
+ 'callproc produced invalid results'
+ )
+ finally:
+ con.close()
+
+ def test_close(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ # cursor.execute should raise an Error if called after connection
+ # closed
+ self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+
+ # connection.commit should raise an Error if called after connection'
+ # closed.'
+ self.assertRaises(self.driver.Error,con.commit)
+
+ # connection.close should raise an Error if called more than once
+ # Issue discussed on DB-SIG: consensus seem that close() should not
+ # raised if called on closed objects. Issue reported back to Stuart.
+ # self.assertRaises(self.driver.Error,con.close)
+
+ def test_execute(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self._paraminsert(cur)
+ finally:
+ con.close()
+
+ def _paraminsert(self,cur):
+ self.executeDDL1(cur)
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.failUnless(cur.rowcount in (-1,1))
+
+ if self.driver.paramstyle == 'qmark':
+ cur.execute(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.execute(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.execute(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.execute(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.execute(
+ 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ else:
+ self.fail('Invalid paramstyle')
+ self.failUnless(cur.rowcount in (-1,1))
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Cooper's",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+ self.assertEqual(beers[1],"Victoria Bitter",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+
+ def test_executemany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ largs = [ ("Cooper's",) , ("Boag's",) ]
+ margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
+ if self.driver.paramstyle == 'qmark':
+ cur.executemany(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.executemany(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.executemany(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ margs
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.executemany(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.executemany(
+ 'insert into %sbooze values (%%(beer)s)' % (
+ self.table_prefix
+ ),
+ margs
+ )
+ else:
+ self.fail('Unknown paramstyle')
+ self.failUnless(cur.rowcount in (-1,2),
+ 'insert using cursor.executemany set cursor.rowcount to '
+ 'incorrect value %r' % cur.rowcount
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,
+ 'cursor.fetchall retrieved incorrect number of rows'
+ )
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
+ self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ self.executeDDL1(cur)
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.failUnless(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if no more rows available'
+ )
+ self.failUnless(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ samples = [
+ 'Carlton Cold',
+ 'Carlton Draft',
+ 'Mountain Goat',
+ 'Redback',
+ 'Victoria Bitter',
+ 'XXXX'
+ ]
+
+ def _populate(self):
+ ''' Return a list of sql commands to setup the DB for the fetch
+ tests.
+ '''
+ populate = [
+ "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+ for s in self.samples
+ ]
+ return populate
+
+ def test_fetchmany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchmany should raise an Error if called without
+ #issuing a query
+ self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchmany retrieved incorrect number of rows, '
+ 'default of arraysize is one.'
+ )
+ cur.arraysize=10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(len(r),3,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(len(r),2,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence after '
+ 'results are exhausted'
+ )
+ self.failUnless(cur.rowcount in (-1,6))
+
+ # Same as above, using cursor.arraysize
+ cur.arraysize=4
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(len(r),4,
+ 'cursor.arraysize not being honoured by fetchmany'
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r),2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r),0)
+ self.failUnless(cur.rowcount in (-1,6))
+
+ cur.arraysize=6
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.failUnless(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows),6)
+ self.assertEqual(len(rows),6)
+ rows = [r[0] for r in rows]
+ rows.sort()
+
+ # Make sure we get the right data back out
+ for i in range(0,6):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved by cursor.fetchmany'
+ )
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'called after the whole result set has been fetched'
+ )
+ self.failUnless(cur.rowcount in (-1,6))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'query retrieved no rows'
+ )
+ self.failUnless(cur.rowcount in (-1,0))
+
+ finally:
+ con.close()
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a a statement that cannot return rows
+ self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.failUnless(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_mixedfetch(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows1 = cur.fetchone()
+ rows23 = cur.fetchmany(2)
+ rows4 = cur.fetchone()
+ rows56 = cur.fetchall()
+ self.failUnless(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows23),2,
+ 'fetchmany returned incorrect number of rows'
+ )
+ self.assertEqual(len(rows56),2,
+ 'fetchall returned incorrect number of rows'
+ )
+
+ rows = [rows1[0]]
+ rows.extend([rows23[0][0],rows23[1][0]])
+ rows.append(rows4[0])
+ rows.extend([rows56[0][0],rows56[1][0]])
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved or inserted'
+ )
+ finally:
+ con.close()
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ raise NotImplementedError('Helper not implemented')
+ #sql="""
+ # create procedure deleteme as
+ # begin
+ # select count(*) from booze
+ # select name from booze
+ # end
+ #"""
+ #cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ raise NotImplementedError('Helper not implemented')
+ #cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ assert s is None, 'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+ def test_arraysize(self):
+ # Not much here - rest of the tests for this are in test_fetchmany
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.failUnless(hasattr(cur,'arraysize'),
+ 'cursor.arraysize must be defined'
+ )
+ finally:
+ con.close()
+
+ def test_setinputsizes(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setinputsizes( (25,) )
+ self._paraminsert(cur) # Make sure cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize_basic(self):
+ # Basic test is to make sure setoutputsize doesn't blow up
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setoutputsize(1000)
+ cur.setoutputsize(2000,0)
+ self._paraminsert(cur) # Make sure the cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize(self):
+ # Real test for setoutputsize is driver dependent
+ raise NotImplementedError('Driver needed to override this test')
+
+ def test_None(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchall()
+ self.assertEqual(len(r),1)
+ self.assertEqual(len(r[0]),1)
+ self.assertEqual(r[0][0],None,'NULL value not returned as None')
+ finally:
+ con.close()
+
+ def test_Date(self):
+ d1 = self.driver.Date(2002,12,25)
+ d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(d1),str(d2))
+
+ def test_Time(self):
+ t1 = self.driver.Time(13,45,30)
+ t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Timestamp(self):
+ t1 = self.driver.Timestamp(2002,12,25,13,45,30)
+ t2 = self.driver.TimestampFromTicks(
+ time.mktime((2002,12,25,13,45,30,0,0,0))
+ )
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Binary(self):
+ b = self.driver.Binary(b'Something')
+ b = self.driver.Binary(b'')
+
+ def test_STRING(self):
+ self.failUnless(hasattr(self.driver,'STRING'),
+ 'module.STRING must be defined'
+ )
+
+ def test_BINARY(self):
+ self.failUnless(hasattr(self.driver,'BINARY'),
+ 'module.BINARY must be defined.'
+ )
+
+ def test_NUMBER(self):
+ self.failUnless(hasattr(self.driver,'NUMBER'),
+ 'module.NUMBER must be defined.'
+ )
+
+ def test_DATETIME(self):
+ self.failUnless(hasattr(self.driver,'DATETIME'),
+ 'module.DATETIME must be defined.'
+ )
+
+ def test_ROWID(self):
+ self.failUnless(hasattr(self.driver,'ROWID'),
+ 'module.ROWID must be defined.'
+ )
+# fmt: on
diff --git a/tests/dbapi20_tpc.py b/tests/dbapi20_tpc.py
new file mode 100644
index 0000000..7254294
--- /dev/null
+++ b/tests/dbapi20_tpc.py
@@ -0,0 +1,151 @@
+# flake8: noqa
+# fmt: off
+
+""" Python DB API 2.0 driver Two Phase Commit compliance test suite.
+
+"""
+
+import unittest
+from typing import Any
+
+
+class TwoPhaseCommitTests(unittest.TestCase):
+
+ driver: Any = None
+
+ def connect(self):
+ """Make a database connection."""
+ raise NotImplementedError
+
+ _last_id = 0
+ _global_id_prefix = "dbapi20_tpc:"
+
+ def make_xid(self, con):
+ id = TwoPhaseCommitTests._last_id
+ TwoPhaseCommitTests._last_id += 1
+ return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier")
+
+ def test_xid(self):
+ con = self.connect()
+ try:
+ try:
+ xid = con.xid(42, "global", "bqual")
+ except self.driver.NotSupportedError:
+ self.fail("Driver does not support transaction IDs.")
+
+ self.assertEquals(xid[0], 42)
+ self.assertEquals(xid[1], "global")
+ self.assertEquals(xid[2], "bqual")
+
+ # Try some extremes for the transaction ID:
+ xid = con.xid(0, "", "")
+ self.assertEquals(tuple(xid), (0, "", ""))
+ xid = con.xid(0x7fffffff, "a" * 64, "b" * 64)
+ self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64))
+ finally:
+ con.close()
+
+ def test_tpc_begin(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ try:
+ con.tpc_begin(xid)
+ except self.driver.NotSupportedError:
+ self.fail("Driver does not support tpc_begin()")
+ finally:
+ con.close()
+
+ def test_tpc_commit_without_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_commit()
+ finally:
+ con.close()
+
+ def test_tpc_rollback_without_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_rollback()
+ finally:
+ con.close()
+
+ def test_tpc_commit_with_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_prepare()
+ con.tpc_commit()
+ finally:
+ con.close()
+
+ def test_tpc_rollback_with_prepare(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ con.tpc_prepare()
+ con.tpc_rollback()
+ finally:
+ con.close()
+
+ def test_tpc_begin_in_transaction_fails(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ self.assertRaises(self.driver.ProgrammingError,
+ con.tpc_begin, xid)
+ finally:
+ con.close()
+
+ def test_tpc_begin_in_tpc_transaction_fails(self):
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+
+ cursor = con.cursor()
+ cursor.execute("SELECT 1")
+ self.assertRaises(self.driver.ProgrammingError,
+ con.tpc_begin, xid)
+ finally:
+ con.close()
+
+ def test_commit_in_tpc_fails(self):
+ # calling commit() within a TPC transaction fails with
+ # ProgrammingError.
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+
+ self.assertRaises(self.driver.ProgrammingError, con.commit)
+ finally:
+ con.close()
+
+ def test_rollback_in_tpc_fails(self):
+ # calling rollback() within a TPC transaction fails with
+ # ProgrammingError.
+ con = self.connect()
+ try:
+ xid = self.make_xid(con)
+ con.tpc_begin(xid)
+
+ self.assertRaises(self.driver.ProgrammingError, con.rollback)
+ finally:
+ con.close()
diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py
new file mode 100644
index 0000000..88ab504
--- /dev/null
+++ b/tests/fix_crdb.py
@@ -0,0 +1,131 @@
+from typing import Optional
+
+import pytest
+
+from .utils import VersionCheck
+from psycopg.crdb import CrdbConnection
+
+
+def pytest_configure(config):
+ # register libpq marker
+ config.addinivalue_line(
+ "markers",
+ "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB"
+ " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')",
+ )
+ config.addinivalue_line(
+ "markers",
+ "crdb_skip(reason): skip the test for known CockroachDB reasons",
+ )
+
+
+def check_crdb_version(got, mark):
+ if mark.name == "crdb":
+ assert len(mark.args) <= 1
+ assert not (set(mark.kwargs) - {"reason"})
+ spec = mark.args[0] if mark.args else "only"
+ reason = mark.kwargs.get("reason")
+ elif mark.name == "crdb_skip":
+ assert len(mark.args) == 1
+ assert not mark.kwargs
+ reason = mark.args[0]
+ assert reason in _crdb_reasons, reason
+ spec = _crdb_reason_version.get(reason, "skip")
+ else:
+ assert False, mark.name
+
+ pred = VersionCheck.parse(spec)
+ pred.whose = "CockroachDB"
+
+ msg = pred.get_skip_message(got)
+ if not msg:
+ return None
+
+ reason = crdb_skip_message(reason)
+ if reason:
+ msg = f"{msg}: {reason}"
+
+ return msg
+
+
+# Utility functions which can be imported in the test suite
+
+is_crdb = CrdbConnection.is_crdb
+
+
+def crdb_skip_message(reason: Optional[str]) -> str:
+ msg = ""
+ if reason:
+ msg = reason
+ if _crdb_reasons.get(reason):
+ url = (
+ "https://github.com/cockroachdb/cockroach/"
+ f"issues/{_crdb_reasons[reason]}"
+ )
+ msg = f"{msg} ({url})"
+
+ return msg
+
+
+def skip_crdb(*args, reason=None):
+ return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason))
+
+
+def crdb_encoding(*args):
+ """Mark tests that fail on CockroachDB because of missing encodings"""
+ return skip_crdb(*args, reason="encoding")
+
+
+def crdb_time_precision(*args):
+ """Mark tests that fail on CockroachDB because time doesn't support precision"""
+ return skip_crdb(*args, reason="time precision")
+
+
+def crdb_scs_off(*args):
+ return skip_crdb(*args, reason="standard_conforming_strings=off")
+
+
+# mapping from reason description to ticket number
+_crdb_reasons = {
+ "2-phase commit": 22329,
+ "backend pid": 35897,
+ "batch statements": 44803,
+ "begin_read_only": 87012,
+ "binary decimal": 82492,
+ "cancel": 41335,
+ "cast adds tz": 51692,
+ "cidr": 18846,
+ "composite": 27792,
+ "copy array": 82792,
+ "copy canceled": 81559,
+ "copy": 41608,
+ "cursor invalid name": 84261,
+ "cursor with hold": 77101,
+ "deferrable": 48307,
+ "do": 17511,
+ "encoding": 35882,
+ "geometric types": 21286,
+ "hstore": 41284,
+ "infinity date": 41564,
+ "interval style": 35807,
+ "json array": 23468,
+ "large objects": 243,
+ "negative interval": 81577,
+ "nested array": 32552,
+ "no col query": None,
+ "notify": 41522,
+ "password_encryption": 42519,
+ "pg_terminate_backend": 35897,
+ "range": 41282,
+ "scroll cursor": 77102,
+ "server-side cursor": 41412,
+ "severity_nonlocalized": 81794,
+ "stored procedure": 1751,
+}
+
+_crdb_reason_version = {
+ "backend pid": "skip < 22",
+ "cancel": "skip < 22",
+ "server-side cursor": "skip < 22.1.3",
+ "severity_nonlocalized": "skip < 22.1.3",
+}
diff --git a/tests/fix_db.py b/tests/fix_db.py
new file mode 100644
index 0000000..3a37aa1
--- /dev/null
+++ b/tests/fix_db.py
@@ -0,0 +1,358 @@
+import io
+import os
+import sys
+import pytest
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg._compat import cache
+from psycopg.pq._debug import PGconnDebug
+
+from .utils import check_postgres_version
+
+# Set by warm_up_database() the first time the dsn fixture is used
+pg_version: int
+crdb_version: Optional[int]
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--test-dsn",
+ metavar="DSN",
+ default=os.environ.get("PSYCOPG_TEST_DSN"),
+ help=(
+ "Connection string to run database tests requiring a connection"
+ " [you can also use the PSYCOPG_TEST_DSN env var]."
+ ),
+ )
+ parser.addoption(
+ "--pq-trace",
+ metavar="{TRACEFILE,STDERR}",
+ default=None,
+ help="Generate a libpq trace to TRACEFILE or STDERR.",
+ )
+ parser.addoption(
+ "--pq-debug",
+ action="store_true",
+ default=False,
+ help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)",
+ )
+
+
+def pytest_report_header(config):
+ dsn = config.getoption("--test-dsn")
+ if dsn is None:
+ return []
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ server_version = conn.execute("select version()").fetchall()[0][0]
+ except Exception as ex:
+ server_version = f"unknown ({ex})"
+
+ return [
+ f"Server version: {server_version}",
+ ]
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ for name in item.fixturenames:
+ if name in ("pipeline", "apipeline"):
+ item.add_marker(pytest.mark.pipeline)
+ break
+
+
+def pytest_runtest_setup(item):
+ for m in item.iter_markers(name="pipeline"):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+
+
+def pytest_configure(config):
+ # register pg marker
+ markers = [
+ "pg(version_expr): run the test only with matching server version"
+ " (e.g. '>= 10', '< 9.6')",
+ "pipeline: the test runs with connection in pipeline mode",
+ ]
+ for marker in markers:
+ config.addinivalue_line("markers", marker)
+
+
+@pytest.fixture(scope="session")
+def session_dsn(request):
+ """
+ Return the dsn used to connect to the `--test-dsn` database (session-wide).
+ """
+ dsn = request.config.getoption("--test-dsn")
+ if dsn is None:
+ pytest.skip("skipping test as no --test-dsn")
+
+ warm_up_database(dsn)
+ return dsn
+
+
+@pytest.fixture
+def dsn(session_dsn, request):
+ """Return the dsn used to connect to the `--test-dsn` database."""
+ check_connection_version(request.node)
+ return session_dsn
+
+
+@pytest.fixture(scope="session")
+def tracefile(request):
+ """Open and yield a file for libpq client/server communication traces if
+ --pq-tracefile option is set.
+ """
+ tracefile = request.config.getoption("--pq-trace")
+ if not tracefile:
+ yield None
+ return
+
+ if tracefile.lower() == "stderr":
+ try:
+ sys.stderr.fileno()
+ except io.UnsupportedOperation:
+ raise pytest.UsageError(
+ "cannot use stderr for --pq-trace (in-memory file?)"
+ ) from None
+
+ yield sys.stderr
+ return
+
+ with open(tracefile, "w") as f:
+ yield f
+
+
+@contextmanager
+def maybe_trace(pgconn, tracefile, function):
+ """Handle libpq client/server communication traces for a single test
+ function.
+ """
+ if tracefile is None:
+ yield None
+ return
+
+ if tracefile != sys.stderr:
+ title = f" {function.__module__}::{function.__qualname__} ".center(80, "=")
+ tracefile.write(title + "\n")
+ tracefile.flush()
+
+ pgconn.trace(tracefile.fileno())
+ try:
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ except psycopg.NotSupportedError:
+ pass
+ try:
+ yield None
+ finally:
+ pgconn.untrace()
+
+
+@pytest.fixture(autouse=True)
+def pgconn_debug(request):
+ if not request.config.getoption("--pq-debug"):
+ return
+ if pq.__impl__ != "python":
+ raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug")
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+ pq.PGconn = PGconnDebug
+
+
+@pytest.fixture
+def pgconn(dsn, request, tracefile):
+ """Return a PGconn connection open to `--test-dsn`."""
+ check_connection_version(request.node)
+
+ conn = pq.PGconn.connect(dsn.encode())
+ if conn.status != pq.ConnStatus.OK:
+ pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
+
+ with maybe_trace(conn, tracefile, request.function):
+ yield conn
+
+ conn.finish()
+
+
+@pytest.fixture
+def conn(conn_cls, dsn, request, tracefile):
+ """Return a `Connection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = conn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+def pipeline(request, conn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ with conn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture
+async def aconn(dsn, aconn_cls, request, tracefile):
+ """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+ check_connection_version(request.node)
+
+ conn = await aconn_cls.connect(dsn)
+ with maybe_trace(conn.pgconn, tracefile, request.function):
+ yield conn
+ await conn.close()
+
+
+@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"])
+async def apipeline(request, aconn):
+ if request.param:
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ async with aconn.pipeline() as p:
+ yield p
+ return
+ else:
+ yield None
+
+
+@pytest.fixture(scope="session")
+def conn_cls(session_dsn):
+ cls = psycopg.Connection
+ if crdb_version:
+ from psycopg.crdb import CrdbConnection
+
+ cls = CrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def aconn_cls(session_dsn):
+ cls = psycopg.AsyncConnection
+ if crdb_version:
+ from psycopg.crdb import AsyncCrdbConnection
+
+ cls = AsyncCrdbConnection
+
+ return cls
+
+
+@pytest.fixture(scope="session")
+def svcconn(conn_cls, session_dsn):
+ """
+ Return a session `Connection` connected to the ``--test-dsn`` database.
+ """
+ conn = conn_cls.connect(session_dsn, autocommit=True)
+ yield conn
+ conn.close()
+
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(conn, monkeypatch)
+
+
+@pytest.fixture
+def acommands(aconn, monkeypatch):
+ """The list of commands issued internally by the test async connection."""
+ yield patch_exec(aconn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+ """Helper to implement the commands fixture both sync and async."""
+ _orig_exec_command = conn._exec_command
+ L = ListPopAll()
+
+ def _exec_command(command, *args, **kwargs):
+ cmdcopy = command
+ if isinstance(cmdcopy, bytes):
+ cmdcopy = cmdcopy.decode(conn.info.encoding)
+ elif isinstance(cmdcopy, sql.Composable):
+ cmdcopy = cmdcopy.as_string(conn)
+
+ L.append(cmdcopy)
+ return _orig_exec_command(command, *args, **kwargs)
+
+ monkeypatch.setattr(conn, "_exec_command", _exec_command)
+ return L
+
+
+class ListPopAll(list): # type: ignore[type-arg]
+ """A list, with a popall() method."""
+
+ def popall(self):
+ out = self[:]
+ del self[:]
+ return out
+
+
+def check_connection_version(node):
+ try:
+ pg_version
+ except NameError:
+ # First connection creation failed. Let the tests fail.
+ pytest.fail("server version not available")
+
+ for mark in node.iter_markers():
+ if mark.name == "pg":
+ assert len(mark.args) == 1
+ msg = check_postgres_version(pg_version, mark.args[0])
+ if msg:
+ pytest.skip(msg)
+
+ elif mark.name in ("crdb", "crdb_skip"):
+ from .fix_crdb import check_crdb_version
+
+ msg = check_crdb_version(crdb_version, mark)
+ if msg:
+ pytest.skip(msg)
+
+
+@pytest.fixture
+def hstore(svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists hstore")
+ except psycopg.Error as e:
+ pytest.skip(str(e))
+
+
+@cache
+def warm_up_database(dsn: str) -> None:
+ """Connect to the database before returning a connection.
+
+ In the CI sometimes, the first test fails with a timeout, probably because
+ the server hasn't started yet. Absorb the delay before the test.
+
+ In case of error, abort the test run entirely, to avoid failing downstream
+ hundreds of times.
+ """
+ global pg_version, crdb_version
+
+ try:
+ with psycopg.connect(dsn, connect_timeout=10) as conn:
+ conn.execute("select 1")
+
+ pg_version = conn.info.server_version
+
+ crdb_version = None
+ param = conn.info.parameter_status("crdb_version")
+ if param:
+ from psycopg.crdb import CrdbConnectionInfo
+
+ crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
+ except Exception as exc:
+ pytest.exit(f"failed to connect to the test database: {exc}")
diff --git a/tests/fix_faker.py b/tests/fix_faker.py
new file mode 100644
index 0000000..5289d8f
--- /dev/null
+++ b/tests/fix_faker.py
@@ -0,0 +1,868 @@
+import datetime as dt
+import importlib
+import ipaddress
+from math import isnan
+from uuid import UUID
+from random import choice, random, randrange
+from typing import Any, List, Set, Tuple, Union
+from decimal import Decimal
+from contextlib import contextmanager, asynccontextmanager
+
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.adapt import PyFormat
+from psycopg._compat import Deque
+from psycopg.types.range import Range
+from psycopg.types.json import Json, Jsonb
+from psycopg.types.numeric import Int4, Int8
+from psycopg.types.multirange import Multirange
+
+
+@pytest.fixture
+def faker(conn):
+ return Faker(conn)
+
+
+class Faker:
+ """
+ An object to generate random records.
+ """
+
+ json_max_level = 3
+ json_max_length = 10
+ str_max_length = 100
+ list_max_length = 20
+ tuple_max_length = 15
+
+ def __init__(self, connection):
+ self.conn = connection
+ self.format = PyFormat.BINARY
+ self.records = []
+
+ self._schema = None
+ self._types = None
+ self._types_names = None
+ self._makers = {}
+ self.table_name = sql.Identifier("fake_table")
+
+ @property
+ def schema(self):
+ if not self._schema:
+ self.schema = self.choose_schema()
+ return self._schema
+
+ @schema.setter
+ def schema(self, schema):
+ self._schema = schema
+ self._types_names = None
+
+ @property
+ def fields_names(self):
+ return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
+
+ @property
+ def types(self):
+ if not self._types:
+
+ def key(cls: type) -> str:
+ return cls.__name__
+
+ self._types = sorted(self.get_supported_types(), key=key)
+ return self._types
+
+ @property
+ def types_names_sql(self):
+ if self._types_names:
+ return self._types_names
+
+ record = self.make_record(nulls=0)
+ tx = psycopg.adapt.Transformer(self.conn)
+ types = [
+ self._get_type_name(tx, schema, value)
+ for schema, value in zip(self.schema, record)
+ ]
+ self._types_names = types
+ return types
+
+ @property
+ def types_names(self):
+ types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql]
+ return types
+
+ def _get_type_name(self, tx, schema, value):
+ # Special case it as it is passed as unknown so is returned as text
+ if schema == (list, str):
+ return sql.SQL("text[]")
+
+ registry = self.conn.adapters.types
+ dumper = tx.get_dumper(value, self.format)
+ dumper.dump(value) # load the oid if it's dynamic (e.g. array)
+ info = registry.get(dumper.oid) or registry.get("text")
+ if dumper.oid == info.array_oid:
+ return sql.SQL("{}[]").format(sql.Identifier(info.name))
+ else:
+ return sql.Identifier(info.name)
+
+ @property
+ def drop_stmt(self):
+ return sql.SQL("drop table if exists {}").format(self.table_name)
+
+ @property
+ def create_stmt(self):
+ field_values = []
+ for name, type in zip(self.fields_names, self.types_names_sql):
+ field_values.append(sql.SQL("{} {}").format(name, type))
+
+ fields = sql.SQL(", ").join(field_values)
+ return sql.SQL("create table {table} (id serial primary key, {fields})").format(
+ table=self.table_name, fields=fields
+ )
+
+ @property
+ def insert_stmt(self):
+ phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))]
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name,
+ sql.SQL(", ").join(self.fields_names),
+ sql.SQL(", ").join(phs),
+ )
+
+ @property
+ def select_stmt(self):
+ fields = sql.SQL(", ").join(self.fields_names)
+ return sql.SQL("select {} from {} order by id").format(fields, self.table_name)
+
+ @contextmanager
+ def find_insert_problem(self, conn):
+ """Context manager to help finding a problematic value."""
+ try:
+ with conn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ cur = conn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ cur.execute(self.drop_stmt)
+ cur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ cur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ @asynccontextmanager
+ async def find_insert_problem_async(self, aconn):
+ try:
+ async with aconn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ acur = aconn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ await acur.execute(self.drop_stmt)
+ await acur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ await acur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ def _insert_field_stmt(self, i):
+ ph = sql.Placeholder(format=self.format)
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name, self.fields_names[i], ph
+ )
+
+ def choose_schema(self, ncols=20):
+ schema: List[Union[Tuple[type, ...], type]] = []
+ while len(schema) < ncols:
+ s = self.make_schema(choice(self.types))
+ if s is not None:
+ schema.append(s)
+ self.schema = schema
+ return schema
+
+ def make_records(self, nrecords):
+ self.records = [self.make_record(nulls=0.05) for i in range(nrecords)]
+
+ def make_record(self, nulls=0):
+ if not nulls:
+ return tuple(self.example(spec) for spec in self.schema)
+ else:
+ return tuple(
+ self.make(spec) if random() > nulls else None for spec in self.schema
+ )
+
+ def assert_record(self, got, want):
+ for spec, g, w in zip(self.schema, got, want):
+ if g is None and w is None:
+ continue
+ m = self.get_matcher(spec)
+ m(spec, g, w)
+
+ def get_supported_types(self) -> Set[type]:
+ dumpers = self.conn.adapters._dumpers[self.format]
+ rv = set()
+ for cls in dumpers.keys():
+ if isinstance(cls, str):
+ cls = deep_import(cls)
+ if issubclass(cls, Multirange) and self.conn.info.server_version < 140000:
+ continue
+
+ rv.add(cls)
+
+ # check all the types are handled
+ for cls in rv:
+ self.get_maker(cls)
+
+ return rv
+
+ def make_schema(self, cls: type) -> Union[Tuple[type, ...], type, None]:
+ """Create a schema spec from a Python type.
+
+ A schema specifies what Postgres type to generate when a Python type
+ maps to more than one (e.g. tuple -> composite, list -> array[],
+ datetime -> timestamp[tz]).
+
+ A schema for a type is represented by a tuple (type, ...) which the
+ matching make_*() method can interpret, or just type if the type
+ doesn't require further specification.
+
+ A `None` means that the type is not supported.
+ """
+ meth = self._get_method("schema", cls)
+ return meth(cls) if meth else cls
+
+ def get_maker(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+
+ try:
+ return self._makers[cls]
+ except KeyError:
+ pass
+
+ meth = self._get_method("make", cls)
+ if meth:
+ self._makers[cls] = meth
+ return meth
+ else:
+ raise NotImplementedError(f"cannot make fake objects of class {cls}")
+
+ def get_matcher(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("match", cls)
+ return meth if meth else self.match_any
+
+ def _get_method(self, prefix, cls):
+ name = cls.__name__
+ if cls.__module__ != "builtins":
+ name = f"{cls.__module__}.{name}"
+
+ parts = name.split(".")
+ for i in range(len(parts)):
+ mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
+ meth = getattr(self, mname, None)
+ if meth:
+ return meth
+
+ return None
+
+ def make(self, spec):
+ # spec can be a type or a tuple (type, options)
+ return self.get_maker(spec)(spec)
+
+ def example(self, spec):
+ # A good representative of the object - no degenerate case
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("example", cls)
+ if meth:
+ return meth(spec)
+ else:
+ return self.make(spec)
+
+ def match_any(self, spec, got, want):
+ assert got == want
+
+ # methods to generate samples of specific types
+
+ def make_Binary(self, spec):
+ return self.make_bytes(spec)
+
+ def match_Binary(self, spec, got, want):
+ return want.obj == got
+
+ def make_bool(self, spec):
+ return choice((True, False))
+
+ def make_bytearray(self, spec):
+ return self.make_bytes(spec)
+
+ def make_bytes(self, spec):
+ length = randrange(self.str_max_length)
+ return spec(bytes([randrange(256) for i in range(length)]))
+
+ def make_date(self, spec):
+ day = randrange(dt.date.max.toordinal())
+ return dt.date.fromordinal(day + 1)
+
+ def schema_datetime(self, cls):
+ return self.schema_time(cls)
+
+ def make_datetime(self, spec):
+ # Add a day because with timezone we might go BC
+ dtmin = dt.datetime.min + dt.timedelta(days=1)
+ delta = dt.datetime.max - dtmin
+ micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
+ rv = dtmin + dt.timedelta(microseconds=micros)
+ if spec[1]:
+ rv = rv.replace(tzinfo=self._make_tz(spec))
+ return rv
+
+ def match_datetime(self, spec, got, want):
+ # Comparisons with different timezones is unreliable: certain pairs
+ # are reported different but their delta is 0
+ # https://bugs.python.org/issue45347
+ assert not (got - want)
+
+ def make_Decimal(self, spec):
+ if random() >= 0.99:
+ return Decimal(choice(self._decimal_special_values()))
+
+ sign = choice("+-")
+ num = choice(["0.zd", "d", "d.d"])
+ while "z" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace("z", "0" * ndigits, 1)
+ while "d" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace(
+ "d", "".join([str(randrange(10)) for i in range(ndigits)]), 1
+ )
+ expsign = choice(["e+", "e-", ""])
+ exp = randrange(20) if expsign else ""
+ rv = Decimal(f"{sign}{num}{expsign}{exp}")
+ return rv
+
+ def match_Decimal(self, spec, got, want):
+ if got is not None and got.is_nan():
+ assert want.is_nan()
+ else:
+ assert got == want
+
+ def _decimal_special_values(self):
+ values = ["NaN", "sNaN"]
+
+ if self.conn.info.vendor == "PostgreSQL":
+ if self.conn.info.server_version >= 140000:
+ values.extend(["Inf", "-Inf"])
+ elif self.conn.info.vendor == "CockroachDB":
+ if self.conn.info.server_version >= 220100:
+ values.extend(["Inf", "-Inf"])
+ else:
+ pytest.fail(f"unexpected vendor: {self.conn.info.vendor}")
+
+ return values
+
+ def schema_Enum(self, cls):
+ # TODO: can't fake those as we would need to create temporary types
+ return None
+
+ def make_Enum(self, spec):
+ return None
+
+ def make_float(self, spec, double=True):
+ if random() <= 0.99:
+ # These exponents should generate no inf
+ return float(
+ f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+ if double
+ else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
+ )
+ else:
+ return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
+
+ def match_float(self, spec, got, want, approx=False, rel=None):
+ if got is not None and isnan(got):
+ assert isnan(want)
+ else:
+ if approx or self._server_rounds():
+ assert got == pytest.approx(want, rel=rel)
+ else:
+ assert got == want
+
+ def _server_rounds(self):
+ """Return True if the connected server perform float rounding"""
+ if self.conn.info.vendor == "CockroachDB":
+ return True
+ else:
+ # Versions older than 12 make some rounding. e.g. in Postgres 10.4
+ # select '-1.409006204063909e+112'::float8
+ # -> -1.40900620406391e+112
+ return self.conn.info.server_version < 120000
+
+ def make_Float4(self, spec):
+ return spec(self.make_float(spec, double=False))
+
+ def match_Float4(self, spec, got, want):
+ self.match_float(spec, got, want, approx=True, rel=1e-5)
+
+ def make_Float8(self, spec):
+ return spec(self.make_float(spec))
+
+ match_Float8 = match_float
+
+ def make_int(self, spec):
+ return randrange(-(1 << 90), 1 << 90)
+
+ def make_Int2(self, spec):
+ return spec(randrange(-(1 << 15), 1 << 15))
+
+ def make_Int4(self, spec):
+ return spec(randrange(-(1 << 31), 1 << 31))
+
+ def make_Int8(self, spec):
+ return spec(randrange(-(1 << 63), 1 << 63))
+
+ def make_IntNumeric(self, spec):
+ return spec(randrange(-(1 << 100), 1 << 100))
+
+ def make_IPv4Address(self, spec):
+ return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4)))
+
+ def make_IPv4Interface(self, spec):
+ prefix = randrange(32)
+ return ipaddress.IPv4Interface(
+ (bytes(randrange(256) for _ in range(4)), prefix)
+ )
+
+ def make_IPv4Network(self, spec):
+ return self.make_IPv4Interface(spec).network
+
+ def make_IPv6Address(self, spec):
+ return ipaddress.IPv6Address(bytes(randrange(256) for _ in range(16)))
+
+ def make_IPv6Interface(self, spec):
+ prefix = randrange(128)
+ return ipaddress.IPv6Interface(
+ (bytes(randrange(256) for _ in range(16)), prefix)
+ )
+
+ def make_IPv6Network(self, spec):
+ return self.make_IPv6Interface(spec).network
+
+ def make_Json(self, spec):
+ return spec(self._make_json())
+
+ def match_Json(self, spec, got, want):
+ if want is not None:
+ want = want.obj
+ assert got == want
+
+ def make_Jsonb(self, spec):
+ return spec(self._make_json())
+
+ def match_Jsonb(self, spec, got, want):
+ self.match_Json(spec, got, want)
+
+ def make_JsonFloat(self, spec):
+ # A float limited to what json accepts
+ # this exponent should generate no inf
+ return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}")
+
+ def schema_list(self, cls):
+ while True:
+ scls = choice(self.types)
+ if scls is cls:
+ continue
+ if scls is float:
+ # TODO: float lists are currently adapted as decimal.
+ # There may be rounding errors or problems with inf.
+ continue
+
+ # CRDB doesn't support arrays of json
+ # https://github.com/cockroachdb/cockroach/issues/23468
+ if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
+ continue
+
+ schema = self.make_schema(scls)
+ if schema is not None:
+ break
+
+ return (cls, schema)
+
+ def make_list(self, spec):
+ # don't make empty lists because they regularly fail cast
+ length = randrange(1, self.list_max_length)
+ spec = spec[1]
+ while True:
+ rv = [self.make(spec) for i in range(length)]
+
+ # TODO multirange lists fail binary dump if the last element is
+ # empty and there is no type annotation. See xfail in
+ # test_multirange::test_dump_builtin_array
+ if rv and isinstance(rv[-1], Multirange) and not rv[-1]:
+ continue
+
+ return rv
+
+ def example_list(self, spec):
+ return [self.example(spec[1])]
+
+ def match_list(self, spec, got, want):
+ assert len(got) == len(want)
+ m = self.get_matcher(spec[1])
+ for g, w in zip(got, want):
+ m(spec[1], g, w)
+
+ def make_memoryview(self, spec):
+ return self.make_bytes(spec)
+
+ def schema_Multirange(self, cls):
+ return self.schema_Range(cls)
+
+ def make_Multirange(self, spec, length=None, **kwargs):
+ if length is None:
+ length = randrange(0, self.list_max_length)
+
+ def overlap(r1, r2):
+ l1, u1 = r1.lower, r1.upper
+ l2, u2 = r2.lower, r2.upper
+ if l1 is None and l2 is None:
+ return True
+ elif l1 is None:
+ l1 = l2
+ elif l2 is None:
+ l2 = l1
+
+ if u1 is None and u2 is None:
+ return True
+ elif u1 is None:
+ u1 = u2
+ elif u2 is None:
+ u2 = u1
+
+ return l1 <= u2 and l2 <= u1
+
+ out: List[Range[Any]] = []
+ for i in range(length):
+ r = self.make_Range((Range, spec[1]), **kwargs)
+ if r.isempty:
+ continue
+ for r2 in out:
+ if overlap(r, r2):
+ insert = False
+ break
+ else:
+ insert = True
+ if insert:
+ out.append(r) # alternatively, we could merge
+
+ return spec[0](sorted(out))
+
+ def example_Multirange(self, spec):
+ return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Multirange(self, spec):
+ return self.make_Multirange((spec, Int4))
+
+ def make_Int8Multirange(self, spec):
+ return self.make_Multirange((spec, Int8))
+
+ def make_NumericMultirange(self, spec):
+ return self.make_Multirange((spec, Decimal))
+
+ def make_DateMultirange(self, spec):
+ return self.make_Multirange((spec, dt.date))
+
+ def make_TimestampMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, False)))
+
+ def make_TimestamptzMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, True)))
+
+ def match_Multirange(self, spec, got, want):
+ assert len(got) == len(want)
+ for ig, iw in zip(got, want):
+ self.match_Range(spec, ig, iw)
+
+ def match_Int4Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int4), got, want)
+
+ def match_Int8Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int8), got, want)
+
+ def match_NumericMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, Decimal), got, want)
+
+ def match_DateMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, dt.date), got, want)
+
+ def match_TimestampMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, True)), got, want)
+
+ def schema_NoneType(self, cls):
+ return None
+
+ def make_NoneType(self, spec):
+ return None
+
+ def make_Oid(self, spec):
+ return spec(randrange(1 << 32))
+
+ def schema_Range(self, cls):
+ subtypes = [
+ Decimal,
+ Int4,
+ Int8,
+ dt.date,
+ (dt.datetime, True),
+ (dt.datetime, False),
+ ]
+
+ return (cls, choice(subtypes))
+
+ def make_Range(self, spec, empty_chance=0.02, no_bound_chance=0.05):
+ # TODO: drop format check after fixing binary dumping of empty ranges
+ # (an array starting with an empty range will get the wrong type currently)
+ if (
+ random() < empty_chance
+ and spec[0] is Range
+ and self.format == PyFormat.TEXT
+ ):
+ return spec[0](empty=True)
+
+ while True:
+ bounds: List[Union[Any, None]] = []
+ while len(bounds) < 2:
+ if random() < no_bound_chance:
+ bounds.append(None)
+ continue
+
+ val = self.make(spec[1])
+ # NaN are allowed in a range, but comparison in Python get tricky.
+ if spec[1] is Decimal and val.is_nan():
+ continue
+
+ bounds.append(val)
+
+ if bounds[0] is not None and bounds[1] is not None:
+ if bounds[0] == bounds[1]:
+ # It would come out empty
+ continue
+
+ if bounds[0] > bounds[1]:
+ bounds.reverse()
+
+ # avoid generating ranges with no type info if dumping in binary
+ # TODO: lift this limitation after test_copy_in_empty xfail is fixed
+ if spec[0] is Range and self.format == PyFormat.BINARY:
+ if bounds[0] is bounds[1] is None:
+ continue
+
+ break
+
+ r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])"))
+ return r
+
+ def example_Range(self, spec):
+ return self.make_Range(spec, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Range(self, spec):
+ return self.make_Range((spec, Int4))
+
+ def make_Int8Range(self, spec):
+ return self.make_Range((spec, Int8))
+
+ def make_NumericRange(self, spec):
+ return self.make_Range((spec, Decimal))
+
+ def make_DateRange(self, spec):
+ return self.make_Range((spec, dt.date))
+
+ def make_TimestampRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, False)))
+
+ def make_TimestamptzRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, True)))
+
+ def match_Range(self, spec, got, want):
+ # normalise the bounds of unbounded ranges
+ if want.lower is None and want.lower_inc:
+ want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
+ if want.upper is None and want.upper_inc:
+ want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+
+ # Normalise discrete ranges
+ unit: Union[dt.timedelta, int, None]
+ if spec[1] is dt.date:
+ unit = dt.timedelta(days=1)
+ elif type(spec[1]) is type and issubclass(spec[1], int):
+ unit = 1
+ else:
+ unit = None
+
+ if unit is not None:
+ if want.lower is not None and not want.lower_inc:
+ want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1])
+ if want.upper_inc:
+ want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")")
+
+ if spec[1] == (dt.datetime, True) and not want.isempty:
+ # work around https://bugs.python.org/issue45347
+ def fix_dt(x):
+ return x.astimezone(dt.timezone.utc) if x is not None else None
+
+ def fix_range(r):
+ return type(r)(fix_dt(r.lower), fix_dt(r.upper), r.bounds)
+
+ want = fix_range(want)
+ got = fix_range(got)
+
+ assert got == want
+
+ def match_Int4Range(self, spec, got, want):
+ return self.match_Range((spec, Int4), got, want)
+
+ def match_Int8Range(self, spec, got, want):
+ return self.match_Range((spec, Int8), got, want)
+
+ def match_NumericRange(self, spec, got, want):
+ return self.match_Range((spec, Decimal), got, want)
+
+ def match_DateRange(self, spec, got, want):
+ return self.match_Range((spec, dt.date), got, want)
+
+ def match_TimestampRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, True)), got, want)
+
+ def make_str(self, spec, length=0):
+ if not length:
+ length = randrange(self.str_max_length)
+
+ rv: List[int] = []
+ while len(rv) < length:
+ c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000)
+ if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF):
+ rv.append(c)
+
+ return "".join(map(chr, rv))
+
+ def schema_time(self, cls):
+ # Choose timezone yes/no
+ return (cls, choice([True, False]))
+
+ def make_time(self, spec):
+ val = randrange(24 * 60 * 60 * 1_000_000)
+ val, ms = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ tz = self._make_tz(spec) if spec[1] else None
+ return dt.time(h, m, s, ms, tz)
+
+ CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239)
+
+ def make_timedelta(self, spec):
+ if self.conn.info.vendor == "CockroachDB":
+ rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX]
+ else:
+ rng = [dt.timedelta.min, dt.timedelta.max]
+
+ return choice(rng) * random()
+
+ def schema_tuple(self, cls):
+ # TODO: this is a complicated matter as it would involve creating
+ # temporary composite types.
+ # length = randrange(1, self.tuple_max_length)
+ # return (cls, self.make_random_schema(ncols=length))
+ return None
+
+ def make_tuple(self, spec):
+ return tuple(self.make(s) for s in spec[1])
+
+ def match_tuple(self, spec, got, want):
+ assert len(got) == len(want) == len(spec[1])
+ for g, w, s in zip(got, want, spec):
+ if g is None or w is None:
+ assert g is w
+ else:
+ m = self.get_matcher(s)
+ m(s, g, w)
+
+ def make_UUID(self, spec):
+ return UUID(bytes=bytes([randrange(256) for i in range(16)]))
+
+ def _make_json(self, container_chance=0.66):
+ rec_types = [list, dict]
+ scal_types = [type(None), int, JsonFloat, bool, str]
+ if random() < container_chance:
+ cls = choice(rec_types)
+ if cls is list:
+ return [
+ self._make_json(container_chance=container_chance / 2.0)
+ for i in range(randrange(self.json_max_length))
+ ]
+ elif cls is dict:
+ return {
+ self.make_str(str, 15): self._make_json(
+ container_chance=container_chance / 2.0
+ )
+ for i in range(randrange(self.json_max_length))
+ }
+ else:
+ assert False, f"unknown rec type: {cls}"
+
+ else:
+ cls = choice(scal_types) # type: ignore[assignment]
+ return self.make(cls)
+
+ def _make_tz(self, spec):
+ minutes = randrange(-12 * 60, 12 * 60 + 1)
+ return dt.timezone(dt.timedelta(minutes=minutes))
+
+
+class JsonFloat:
+ pass
+
+
+def deep_import(name):
+ parts = Deque(name.split("."))
+ seen = []
+ if not parts:
+ raise ValueError("name must be a dot-separated name")
+
+ seen.append(parts.popleft())
+ thing = importlib.import_module(seen[-1])
+ while parts:
+ attr = parts.popleft()
+ seen.append(attr)
+
+ if hasattr(thing, attr):
+ thing = getattr(thing, attr)
+ else:
+ thing = importlib.import_module(".".join(seen))
+
+ return thing
diff --git a/tests/fix_mypy.py b/tests/fix_mypy.py
new file mode 100644
index 0000000..b860a32
--- /dev/null
+++ b/tests/fix_mypy.py
@@ -0,0 +1,54 @@
+import re
+import subprocess as sp
+
+import pytest
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers",
+ "mypy: the test uses mypy (the marker is set automatically"
+ " on tests using the fixture)",
+ )
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ if "mypy" in item.fixturenames:
+ # add a mypy tag so we can address these tests only
+ item.add_marker(pytest.mark.mypy)
+
+ # All the tests using mypy are slow
+ item.add_marker(pytest.mark.slow)
+
+
+@pytest.fixture(scope="session")
+def mypy(tmp_path_factory):
+ cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
+ src_dir = tmp_path_factory.mktemp("source")
+
+ class MypyRunner:
+ def run_on_file(self, filename):
+ cmdline = f"""
+ mypy
+ --strict
+ --show-error-codes --no-color-output --no-error-summary
+ --config-file= --cache-dir={cache_dir}
+ """.split()
+ cmdline.append(filename)
+ return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
+
+ def run_on_source(self, source):
+ fn = src_dir / "tmp.py"
+ with fn.open("w") as f:
+ f.write(source)
+
+ return self.run_on_file(str(fn))
+
+ def get_revealed(self, line):
+ """return the type from an output of reveal_type"""
+ return re.sub(
+ r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line
+ ).replace("*", "")
+
+ return MypyRunner()
diff --git a/tests/fix_pq.py b/tests/fix_pq.py
new file mode 100644
index 0000000..6811a26
--- /dev/null
+++ b/tests/fix_pq.py
@@ -0,0 +1,141 @@
+import os
+import sys
+import ctypes
+from typing import Iterator, List, NamedTuple
+from tempfile import TemporaryFile
+
+import pytest
+
+from .utils import check_libpq_version
+
+try:
+ from psycopg import pq
+except ImportError:
+ pq = None # type: ignore
+
+
+def pytest_report_header(config):
+ try:
+ from psycopg import pq
+ except ImportError:
+ return []
+
+ return [
+ f"libpq wrapper implementation: {pq.__impl__}",
+ f"libpq used: {pq.version()}",
+ f"libpq compiled: {pq.__build_version__}",
+ ]
+
+
+def pytest_configure(config):
+ # register libpq marker
+ config.addinivalue_line(
+ "markers",
+ "libpq(version_expr): run the test only with matching libpq"
+ " (e.g. '>= 10', '< 9.6')",
+ )
+
+
+def pytest_runtest_setup(item):
+ for m in item.iter_markers(name="libpq"):
+ assert len(m.args) == 1
+ msg = check_libpq_version(pq.version(), m.args[0])
+ if msg:
+ pytest.skip(msg)
+
+
+@pytest.fixture
+def libpq():
+ """Return a ctypes wrapper to access the libpq."""
+ try:
+ from psycopg.pq.misc import find_libpq_full_path
+
+ # Not available when testing the binary package
+ libname = find_libpq_full_path()
+ assert libname, "libpq libname not found"
+ return ctypes.pydll.LoadLibrary(libname)
+ except Exception as e:
+ if pq.__impl__ == "binary":
+ pytest.skip(f"can't load libpq for testing: {e}")
+ else:
+ raise
+
+
+@pytest.fixture
+def setpgenv(monkeypatch):
+ """Replace the PG* env vars with the vars provided."""
+
+ def setpgenv_(env):
+ ks = [k for k in os.environ if k.startswith("PG")]
+ for k in ks:
+ monkeypatch.delenv(k)
+
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+
+ return setpgenv_
+
+
+@pytest.fixture
+def trace(libpq):
+ pqver = pq.__build_version__
+ if pqver < 140000:
+ pytest.skip(f"trace not available on libpq {pqver}")
+ if sys.platform != "linux":
+ pytest.skip(f"trace not available on {sys.platform}")
+
+ yield Tracer()
+
+
+class Tracer:
+ def trace(self, conn):
+ pgconn: "pq.abc.PGconn"
+
+ if hasattr(conn, "exec_"):
+ pgconn = conn
+ elif hasattr(conn, "cursor"):
+ pgconn = conn.pgconn
+ else:
+ raise Exception()
+
+ return TraceLog(pgconn)
+
+
+class TraceLog:
+ def __init__(self, pgconn: "pq.abc.PGconn"):
+ self.pgconn = pgconn
+ self.tempfile = TemporaryFile(buffering=0)
+ pgconn.trace(self.tempfile.fileno())
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS)
+
+ def __del__(self):
+ if self.pgconn.status == pq.ConnStatus.OK:
+ self.pgconn.untrace()
+ self.tempfile.close()
+
+ def __iter__(self) -> "Iterator[TraceEntry]":
+ self.tempfile.seek(0)
+ data = self.tempfile.read()
+ for entry in self._parse_entries(data):
+ yield entry
+
+ def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
+ for line in data.splitlines():
+ direction, length, type, *content = line.split(b"\t")
+ yield TraceEntry(
+ direction=direction.decode(),
+ length=int(length.decode()),
+ type=type.decode(),
+ # Note: the items encoding is not very solid: no escaped
+ # backslash, no escaped quotes.
+ # At the moment we don't need a proper parser.
+ content=[content[0]] if content else [],
+ )
+
+
+class TraceEntry(NamedTuple):
+ direction: str
+ length: int
+ type: str
+ content: List[bytes]
diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py
new file mode 100644
index 0000000..e50f5ec
--- /dev/null
+++ b/tests/fix_proxy.py
@@ -0,0 +1,127 @@
+import os
+import time
+import socket
+import logging
+import subprocess as sp
+from shutil import which
+
+import pytest
+
+import psycopg
+from psycopg import conninfo
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ # TODO: there is a race condition on macOS and Windows in the CI:
+ # listen returns before really listening and tests based on 'deaf_port'
+ # fail 50% of the times. Just add the 'proxy' mark on these tests
+ # because they are already skipped in the CI.
+ if "proxy" in item.fixturenames or "deaf_port" in item.fixturenames:
+ item.add_marker(pytest.mark.proxy)
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers",
+ "proxy: the test uses pproxy (the marker is set automatically"
+ " on tests using the fixture)",
+ )
+
+
+@pytest.fixture
+def proxy(dsn):
+ """Return a proxy to the --test-dsn database"""
+ p = Proxy(dsn)
+ yield p
+ p.stop()
+
+
+@pytest.fixture
+def deaf_port(dsn):
+ """Return a port number with a socket open but not answering"""
+ with socket.socket(socket.AF_INET) as s:
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ s.listen(0)
+ yield port
+
+
+class Proxy:
+ """
+ Proxy a Postgres service for testing purpose.
+
+ Allow to lose connectivity and restart it using stop/start.
+ """
+
+ def __init__(self, server_dsn):
+ cdict = conninfo.conninfo_to_dict(server_dsn)
+
+ # Get server params
+ host = cdict.get("host") or os.environ.get("PGHOST")
+ self.server_host = host if host and not host.startswith("/") else "localhost"
+ self.server_port = cdict.get("port", "5432")
+
+ # Get client params
+ self.client_host = "localhost"
+ self.client_port = self._get_random_port()
+
+ # Make a connection string to the proxy
+ cdict["host"] = self.client_host
+ cdict["port"] = self.client_port
+ cdict["sslmode"] = "disable" # not supported by the proxy
+ self.client_dsn = conninfo.make_conninfo(**cdict)
+
+ # The running proxy process
+ self.proc = None
+
+ def start(self):
+ if self.proc:
+ logging.info("proxy already started")
+ return
+
+ logging.info("starting proxy")
+ pproxy = which("pproxy")
+ if not pproxy:
+ raise ValueError("pproxy program not found")
+ cmdline = [pproxy, "--reuse"]
+ cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
+ cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"])
+
+ self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL)
+ logging.info("proxy started")
+ self._wait_listen()
+
+ # verify that the proxy works
+ try:
+ with psycopg.connect(self.client_dsn):
+ pass
+ except Exception as e:
+ pytest.fail(f"failed to create a working proxy: {e}")
+
+ def stop(self):
+ if not self.proc:
+ return
+
+ logging.info("stopping proxy")
+ self.proc.terminate()
+ self.proc.wait()
+ logging.info("proxy stopped")
+ self.proc = None
+
+ @classmethod
+ def _get_random_port(cls):
+ with socket.socket() as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+ def _wait_listen(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ for i in range(20):
+ if 0 == sock.connect_ex((self.client_host, self.client_port)):
+ break
+ time.sleep(0.1)
+ else:
+ raise ValueError("the proxy didn't start listening in time")
+
+ logging.info("proxy listening")
diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py
new file mode 100644
index 0000000..80e0c62
--- /dev/null
+++ b/tests/fix_psycopg.py
@@ -0,0 +1,98 @@
+from copy import deepcopy
+
+import pytest
+
+
+@pytest.fixture
+def global_adapters():
+ """Restore the global adapters after a test has changed them."""
+ from psycopg import adapters
+
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+@pytest.fixture
+@pytest.mark.crdb_skip("2-phase commit")
+def tpc(svcconn):
+ tpc = Tpc(svcconn)
+ tpc.check_tpc()
+ tpc.clear_test_xacts()
+ tpc.make_test_table()
+ yield tpc
+ tpc.clear_test_xacts()
+
+
+class Tpc:
+ """Helper object to test two-phase transactions"""
+
+ def __init__(self, conn):
+ assert conn.autocommit
+ self.conn = conn
+
+ def check_tpc(self):
+ from .fix_crdb import is_crdb, crdb_skip_message
+
+ if is_crdb(self.conn):
+ pytest.skip(crdb_skip_message("2-phase commit"))
+
+ val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0])
+ if not val:
+ pytest.skip("prepared transactions disabled in the database")
+
+ def clear_test_xacts(self):
+ """Rollback all the prepared transaction in the testing db."""
+ from psycopg import sql
+
+ cur = self.conn.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (self.conn.info.dbname,),
+ )
+ gids = [r[0] for r in cur]
+ for gid in gids:
+ self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+ def make_test_table(self):
+ self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+ self.conn.execute("TRUNCATE test_tpc")
+
+ def count_xacts(self):
+ """Return the number of prepared xacts currently in the test db."""
+ cur = self.conn.execute(
+ """
+ select count(*) from pg_prepared_xacts
+ where database = %s""",
+ (self.conn.info.dbname,),
+ )
+ return cur.fetchone()[0]
+
+ def count_test_records(self):
+ """Return the number of records in the test table."""
+ cur = self.conn.execute("select count(*) from test_tpc")
+ return cur.fetchone()[0]
+
+
+@pytest.fixture(scope="module")
+def generators():
+ """Return the 'generators' module for selected psycopg implementation."""
+ from psycopg import pq
+
+ if pq.__impl__ == "c":
+ from psycopg._cmodule import _psycopg
+
+ return _psycopg
+ else:
+ import psycopg.generators
+
+ return psycopg.generators
diff --git a/tests/pool/__init__.py b/tests/pool/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/pool/__init__.py
diff --git a/tests/pool/fix_pool.py b/tests/pool/fix_pool.py
new file mode 100644
index 0000000..12e4f39
--- /dev/null
+++ b/tests/pool/fix_pool.py
@@ -0,0 +1,12 @@
+import pytest
+
+
+def pytest_configure(config):
+ config.addinivalue_line("markers", "pool: test related to the psycopg_pool package")
+
+
+def pytest_collection_modifyitems(items):
+ # Add the pool markers to all the tests in the pool package
+ for item in items:
+ if "/pool/" in item.nodeid:
+ item.add_marker(pytest.mark.pool)
diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py
new file mode 100644
index 0000000..c0e8060
--- /dev/null
+++ b/tests/pool/test_null_pool.py
@@ -0,0 +1,896 @@
+import logging
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+from .test_pool import delay_connection, ensure_waiting
+
+try:
+ from psycopg_pool import NullConnectionPool
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+def test_defaults(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+def test_min_size_max_size(dsn):
+ with NullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ NullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+ class MyConn(psycopg.Connection[Any]):
+ pass
+
+ with NullConnectionPool(dsn, connection_class=MyConn) as p:
+ with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+ with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_its_no_pool_at_all(dsn):
+ with NullConnectionPool(dsn, max_size=2) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ with p.connection() as conn:
+ assert conn.info.backend_pid not in (pid1, pid2)
+
+
+def test_context(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.1)
+
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.4)
+
+
+def test_wait_closed(dsn):
+ with NullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ p.wait(0.2)
+
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+
+def test_configure(dsn):
+ inits = 0
+
+ def configure(conn):
+ nonlocal inits
+ inits += 1
+ with conn.transaction():
+ conn.execute("set default_transaction_read_only to on")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+ with p.connection() as conn:
+ assert inits == 2
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+ conn.close()
+
+ with p.connection() as conn:
+ assert inits == 3
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ conn.execute("select 1")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ assert resets == 1
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+ p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+def test_no_queue_timeout(deaf_port):
+ with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
+ with pytest.raises(PoolTimeout):
+ with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue(dsn):
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ with NullConnectionPool(dsn, max_size=2) as p:
+ p.wait()
+ ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+ def worker(t, ev=None):
+ try:
+ with p.connection():
+ if ev:
+ ev.set()
+ sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ p.wait()
+ ev = Event()
+ t = Thread(target=worker, args=(0.3, ev))
+ t.start()
+ ev.wait()
+
+ ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout(dsn):
+ def worker(n):
+ t0 = time()
+ try:
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ results: List[int] = []
+
+ with NullConnectionPool(dsn, max_size=2) as p:
+ ts = [
+ Thread(target=worker, args=(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout_override(dsn):
+ def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_broken_reconnect(dsn):
+ with NullConnectionPool(dsn, max_size=1) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ conn.close()
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert not conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ t = Thread(target=worker)
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ t = Thread(target=worker, args=(p,))
+ t.start()
+ ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner and p._sched_runner.is_alive()
+ workers = p._workers[:]
+ assert workers
+ for t in workers:
+ assert t.is_alive()
+
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert not t.is_alive()
+
+
+def test_putconn_no_pool(conn_cls, dsn):
+ with NullConnectionPool(dsn) as p:
+ conn = conn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+ conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+ with NullConnectionPool(dsn) as p1:
+ with NullConnectionPool(dsn) as p2:
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner is not None
+ ts = [p._sched_runner] + p._workers
+ del p
+ sleep(0.1)
+ for t in ts:
+ assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+ p = NullConnectionPool(dsn)
+ assert not p.closed
+ with p.connection():
+ pass
+
+ p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+
+def test_closed_putconn(dsn):
+ p = NullConnectionPool(dsn)
+
+ with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ with p.connection() as conn:
+ p.close()
+ assert conn.closed
+
+
+def test_closed_queue(dsn):
+ def w1():
+ with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+ e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ def w2():
+ try:
+ with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = Event()
+ e2 = Event()
+
+ p = NullConnectionPool(dsn, max_size=1)
+ p.wait()
+ success: List[str] = []
+
+ t1 = Thread(target=w1)
+ t1.start()
+ # Wait until w1 has received a connection
+ e1.wait()
+
+ t2 = Thread(target=w2)
+ t2.start()
+ # Wait until w2 is in the queue
+ ensure_waiting(p)
+
+ p.close(0)
+
+ # Wait for the workers to finish
+ e2.set()
+ t1.join()
+ t2.join()
+ assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ p.getconn()
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+ p.open()
+ try:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ p.getconn()
+
+
+def test_open_context(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ with p:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+def test_open_no_op(dsn):
+ p = NullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ p.open()
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+
+def test_reopen(dsn):
+ p = NullConnectionPool(dsn)
+ with p.connection() as conn:
+ conn.execute("select 1")
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ p.open()
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_resize(dsn, min_size, max_size):
+ with NullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_max_lifetime(dsn):
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ sleep(0.1)
+
+ ts = []
+ with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ for i in range(5):
+ ts.append(Thread(target=worker, args=(p,)))
+ ts[-1].start()
+
+ for t in ts:
+ t.join()
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+def test_check(dsn):
+ with NullConnectionPool(dsn) as p:
+ # No-op
+ p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+
+ with NullConnectionPool(dsn, max_size=4) as p:
+ p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ p.wait(2.0)
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn):
+ def worker(n):
+ try:
+ with p.connection(timeout=0.3) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ with NullConnectionPool(dsn, max_size=3) as p:
+ p.wait(2.0)
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ with p.connection() as conn:
+ conn.close()
+ p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ with NullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py
new file mode 100644
index 0000000..23a1a52
--- /dev/null
+++ b/tests/pool/test_null_pool_async.py
@@ -0,0 +1,844 @@
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task
+from .test_pool_async import delay_connection, ensure_waiting
+
+pytestmark = [pytest.mark.asyncio]
+
+try:
+ from psycopg_pool import AsyncNullConnectionPool # noqa: F401
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+async def test_defaults(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+async def test_min_size_max_size(dsn):
+ async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+async def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+ class MyConn(psycopg.AsyncConnection[Any]):
+ pass
+
+ async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p:
+ async with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+ async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_its_no_pool_at_all(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ async with p.connection() as conn:
+ assert conn.info.backend_pid not in (pid1, pid2)
+
+
+async def test_context(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.1)
+
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.4)
+
+
+async def test_wait_closed(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ await p.wait(0.2)
+
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+ inits = 0
+
+ async def configure(conn):
+ nonlocal inits
+ inits += 1
+ async with conn.transaction():
+ await conn.execute("set default_transaction_read_only to on")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+ async with p.connection() as conn:
+ assert inits == 2
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+ await conn.close()
+
+ async with p.connection() as conn:
+ assert inits == 3
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ await conn.execute("select 1")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ assert resets == 1
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+ await p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ await conn.execute("select 1")
+ pids.append(conn.info.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+async def test_no_queue_timeout(deaf_port):
+ async with AsyncNullConnectionPool(
+ kwargs={"host": "localhost", "port": deaf_port}
+ ) as p:
+ with pytest.raises(PoolTimeout):
+ async with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue(dsn):
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ await p.wait()
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+ async def worker(t, ev=None):
+ try:
+ async with p.connection():
+ if ev:
+ ev.set()
+ await asyncio.sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ await p.wait()
+ ev = asyncio.Event()
+ create_task(worker(0.3, ev))
+ await ev.wait()
+
+ ts = [create_task(worker(0.1)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout(dsn):
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ results: List[int] = []
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout_override(dsn):
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_broken_reconnect(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ await conn.close()
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ cur = await conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+async def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ pids.append(conn.info.backend_pid)
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+ t = create_task(worker())
+ await ensure_waiting(p)
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pids.append(conn.info.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert p._sched_runner and not p._sched_runner.done()
+ assert p._workers
+ workers = p._workers[:]
+ for t in workers:
+ assert not t.done()
+
+ await p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert t.done()
+
+
+async def test_putconn_no_pool(aconn_cls, dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ conn = await aconn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+
+ await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ async with AsyncNullConnectionPool(dsn) as p1:
+ async with AsyncNullConnectionPool(dsn) as p2:
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+
+ async with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+async def test_closed_queue(dsn):
+ async def w1():
+ async with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ await e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ async def w2():
+ try:
+ async with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = asyncio.Event()
+ e2 = asyncio.Event()
+
+ p = AsyncNullConnectionPool(dsn, max_size=1)
+ await p.wait()
+ success: List[str] = []
+
+ t1 = create_task(w1())
+ # Wait until w1 has received a connection
+ await e1.wait()
+
+ t2 = create_task(w2())
+ # Wait until w2 is in the queue
+ await ensure_waiting(p)
+ await p.close()
+
+ # Wait for the workers to finish
+ e2.set()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed):
+ await p.getconn()
+
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ async with p.connection():
+ pass
+
+ await p.open()
+ try:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ await p.getconn()
+
+
+async def test_open_context(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ async with p:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+async def test_open_no_op(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ await p.open()
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+
+async def test_reopen(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ await p.close()
+ assert p._sched_runner is None
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ await p.open()
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+async def test_bad_resize(dsn, min_size, max_size):
+ async with AsyncNullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ await p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_max_lifetime(dsn):
+ pids: List[int] = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ await asyncio.sleep(0.1)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ ts = [create_task(worker()) for i in range(5)]
+ await asyncio.gather(*ts)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+async def test_check(dsn):
+ # no.op
+ async with AsyncNullConnectionPool(dsn) as p:
+ await p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+
+ async with AsyncNullConnectionPool(dsn, max_size=4) as p:
+ await p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [create_task(worker(i)) for i in range(3)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ await p.wait(2.0)
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn):
+ async def worker(n):
+ try:
+ async with p.connection(timeout=0.3) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ async with AsyncNullConnectionPool(dsn, max_size=3) as p:
+ await p.wait(2.0)
+
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.gather(*ts)
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ async with p.connection() as conn:
+ await conn.close()
+ await p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ async with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ await p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py
new file mode 100644
index 0000000..30c790b
--- /dev/null
+++ b/tests/pool/test_pool.py
@@ -0,0 +1,1265 @@
+import logging
+import weakref
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import Counter
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+
+def test_package_version(mypy):
+ cp = mypy.run_on_source(
+ """\
+from psycopg_pool import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
+
+
+def test_defaults(dsn):
+ with pool.ConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 4
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
+def test_min_size_max_size(dsn, min_size, max_size):
+ with pool.ConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
+ assert p.min_size == min_size
+ assert p.max_size == max_size if max_size is not None else min_size
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
+def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ pool.ConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+ class MyConn(psycopg.Connection[Any]):
+ pass
+
+ with pool.ConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
+ with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+ with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p:
+ with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_its_really_a_pool(dsn):
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ with p.connection() as conn:
+ assert conn.info.backend_pid in (pid1, pid2)
+
+
+def test_context(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_connection_not_lost(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ with pytest.raises(ZeroDivisionError):
+ with p.connection() as conn:
+ pid = conn.info.backend_pid
+ 1 / 0
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_concurrent_filling(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+
+ def add_time(self, conn):
+ times.append(time() - t0)
+ add_orig(self, conn)
+
+ add_orig = pool.ConnectionPool._add_to_pool
+ monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", add_time)
+
+ times: List[float] = []
+ t0 = time()
+
+ with pool.ConnectionPool(dsn, min_size=5, num_workers=2) as p:
+ p.wait(1.0)
+ want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+ assert len(times) == len(want_times)
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.wait(0.3)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.wait(0.5)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=2) as p:
+ p.wait(0.3)
+ p.wait(0.0001) # idempotent
+
+
+def test_wait_closed(dsn):
+ with pool.ConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(pool.PoolClosed):
+ p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
+ p.wait(0.2)
+
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
+ sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+
+def test_configure(dsn):
+ inits = 0
+
+ def configure(conn):
+ nonlocal inits
+ inits += 1
+ with conn.transaction():
+ conn.execute("set default_transaction_read_only to on")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ p.wait()
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+ conn.close()
+
+ with p.connection() as conn:
+ assert inits == 2
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ conn.execute("select 1")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+
+ p.wait()
+ assert resets == 1
+
+ with p.connection() as conn:
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+
+ p.wait()
+ assert resets == 2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue(dsn):
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ p.wait()
+ ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+ def worker(t, ev=None):
+ try:
+ with p.connection():
+ if ev:
+ ev.set()
+ sleep(t)
+ except pool.TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p:
+ p.wait()
+ ev = Event()
+ t = Thread(target=worker, args=(0.3, ev))
+ t.start()
+ ev.wait()
+
+ ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], pool.TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout(dsn):
+ def worker(n):
+ t0 = time()
+ try:
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ results: List[int] = []
+
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ ts = [
+ Thread(target=worker, args=(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_queue_timeout_override(dsn):
+ def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_broken_reconnect(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ conn.close()
+
+ with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+ assert not conn2.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+ pid = conn.info.backend_pid
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = p.getconn()
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+ p = pool.ConnectionPool(dsn)
+ assert p._sched_runner and p._sched_runner.is_alive()
+ workers = p._workers[:]
+ assert workers
+ for t in workers:
+ assert t.is_alive()
+
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert not t.is_alive()
+
+
+def test_putconn_no_pool(conn_cls, dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p:
+ conn = conn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+ conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+ with pool.ConnectionPool(dsn, min_size=1) as p1:
+ with pool.ConnectionPool(dsn, min_size=1) as p2:
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)
+
+
+def test_del_no_warning(dsn, recwarn):
+ p = pool.ConnectionPool(dsn, min_size=2)
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ p.wait()
+ ref = weakref.ref(p)
+ del p
+ assert not ref()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+ p = pool.ConnectionPool(dsn)
+ assert p._sched_runner is not None
+ ts = [p._sched_runner] + p._workers
+ del p
+ sleep(0.1)
+ for t in ts:
+ assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+ p = pool.ConnectionPool(dsn, min_size=1)
+ assert not p.closed
+ with p.connection():
+ pass
+
+ p.close()
+ assert p.closed
+
+ with pytest.raises(pool.PoolClosed):
+ with p.connection():
+ pass
+
+
+def test_closed_putconn(dsn):
+ p = pool.ConnectionPool(dsn, min_size=1)
+
+ with p.connection() as conn:
+ pass
+ assert not conn.closed
+
+ with p.connection() as conn:
+ p.close()
+ assert conn.closed
+
+
+def test_closed_queue(dsn):
+ def w1():
+ with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+ e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ def w2():
+ try:
+ with p.connection():
+ pass # unexpected
+ except pool.PoolClosed:
+ success.append("w2")
+
+ e1 = Event()
+ e2 = Event()
+
+ p = pool.ConnectionPool(dsn, min_size=1)
+ p.wait()
+ success: List[str] = []
+
+ t1 = Thread(target=w1)
+ t1.start()
+ # Wait until w1 has received a connection
+ e1.wait()
+
+ t2 = Thread(target=w2)
+ t2.start()
+ # Wait until w2 is in the queue
+ ensure_waiting(p)
+
+ p.close(0)
+
+ # Wait for the workers to finish
+ e2.set()
+ t1.join()
+ t2.join()
+ assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+ p = pool.ConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(pool.PoolClosed, match="is not open yet"):
+ p.getconn()
+
+ with pytest.raises(pool.PoolClosed):
+ with p.connection():
+ pass
+
+ p.open()
+ try:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+ with pytest.raises(pool.PoolClosed, match="is already closed"):
+ p.getconn()
+
+
+def test_open_context(dsn):
+ p = pool.ConnectionPool(dsn, open=False)
+ assert p.closed
+
+ with p:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+def test_open_no_op(dsn):
+ p = pool.ConnectionPool(dsn)
+ try:
+ assert not p.closed
+ p.open()
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_open_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ p.open(wait=True, timeout=0.3)
+ finally:
+ p.close()
+
+ p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ p.open(wait=True, timeout=0.5)
+ finally:
+ p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_open_as_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.open(wait=True, timeout=0.3)
+
+ with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ p.open(wait=True, timeout=0.5)
+
+
+def test_reopen(dsn):
+ p = pool.ConnectionPool(dsn)
+ with p.connection() as conn:
+ conn.execute("select 1")
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ p.open()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize(
+ "min_size, want_times",
+ [
+ (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]),
+ (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]),
+ ],
+)
+def test_grow(dsn, monkeypatch, min_size, want_times):
+ delay_connection(monkeypatch, 0.1)
+
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select 1 from pg_sleep(0.25)")
+ t1 = time()
+ results.append((n, t1 - t0))
+
+ with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p:
+ p.wait(1.0)
+ results: List[Tuple[int, float]] = []
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_shrink(dsn, monkeypatch):
+
+ from psycopg_pool.pool import ShrinkPool
+
+ results: List[Tuple[int, int]] = []
+
+ def run_hacked(self, pool):
+ n0 = pool._nconns
+ orig_run(self, pool)
+ n1 = pool._nconns
+ results.append((n0, n1))
+
+ orig_run = ShrinkPool._run
+ monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
+
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.1)")
+
+ with pool.ConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
+ p.wait(5.0)
+ assert p.max_idle == 0.2
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(1)
+
+ assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+def test_reconnect(proxy, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+ caplog.clear()
+ proxy.start()
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1) as p:
+ p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ sleep(1.0)
+ proxy.start()
+ p.wait()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ assert "BAD" in caplog.messages[0]
+ times = [rec.created for rec in caplog.records]
+ assert times[1] - times[0] < 0.05
+ deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+ assert len(deltas) == 3
+ want = 0.1
+ for delta in deltas:
+ assert delta == pytest.approx(want, 0.05), deltas
+ want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_reconnect_failure(proxy):
+ proxy.start()
+
+ t1 = None
+
+ def failed(pool):
+ assert pool.name == "this-one"
+ nonlocal t1
+ t1 = time()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn,
+ name="this-one",
+ min_size=1,
+ reconnect_timeout=1.0,
+ reconnect_failed=failed,
+ ) as p:
+ p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+ t0 = time()
+ sleep(1.5)
+ assert t1
+ assert t1 - t0 == pytest.approx(1.0, 0.1)
+ assert p._nconns == 0
+
+ proxy.start()
+ t0 = time()
+ with p.connection() as conn:
+ conn.execute("select 1")
+ t1 = time()
+ assert t1 - t0 < 0.2
+
+
+@pytest.mark.slow
+def test_reconnect_after_grow_failed(proxy):
+ # Retry reconnection after a failed connection attempt has put the pool
+ # in grow mode. See issue #370.
+ proxy.stop()
+
+ ev = Event()
+
+ def failed(pool):
+ ev.set()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ assert ev.wait(timeout=2)
+
+ with pytest.raises(pool.PoolTimeout):
+ with p.connection(timeout=0.5) as conn:
+ pass
+
+ ev.clear()
+ assert ev.wait(timeout=2)
+
+ proxy.start()
+
+ with p.connection(timeout=2) as conn:
+ conn.execute("select 1")
+
+ p.wait(timeout=3.0)
+ assert len(p._pool) == p.min_size == 4
+
+
+@pytest.mark.slow
+def test_refill_on_check(proxy):
+ proxy.start()
+ ev = Event()
+
+ def failed(pool):
+ ev.set()
+
+ with pool.ConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ # The pool is full
+ p.wait(timeout=2)
+
+ # Break all the connection
+ proxy.stop()
+
+ # Checking the pool will empty it
+ p.check()
+ assert ev.wait(timeout=2)
+ assert len(p._pool) == 0
+
+ # Allow to connect again
+ proxy.start()
+
+ # Make sure that check has refilled the pool
+ p.check()
+ p.wait(timeout=2)
+ assert len(p._pool) == 4
+
+
+@pytest.mark.slow
+def test_uniform_use(dsn):
+ with pool.ConnectionPool(dsn, min_size=4) as p:
+ counts = Counter[int]()
+ for i in range(8):
+ with p.connection() as conn:
+ sleep(0.1)
+ counts[id(conn)] += 1
+
+ assert len(counts) == 4
+ assert set(counts.values()) == set([2])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_resize(dsn):
+ def sampler():
+ sleep(0.05) # ensure sampling happens after shrink check
+ while True:
+ sleep(0.2)
+ if p.closed:
+ break
+ size.append(len(p._pool))
+
+ def client(t):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(%s)", [t])
+
+ size: List[int] = []
+
+ with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
+ s = Thread(target=sampler)
+ s.start()
+
+ sleep(0.3)
+ c = Thread(target=client, args=(0.4,))
+ c.start()
+
+ sleep(0.2)
+ p.resize(4)
+ assert p.min_size == 4
+ assert p.max_size == 4
+
+ sleep(0.4)
+ p.resize(2)
+ assert p.min_size == 2
+ assert p.max_size == 2
+
+ sleep(0.6)
+
+ s.join()
+ assert size == [2, 1, 3, 4, 3, 2, 2]
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)])
+def test_bad_resize(dsn, min_size, max_size):
+ with pool.ConnectionPool() as p:
+ with pytest.raises(ValueError):
+ p.resize(min_size=min_size, max_size=max_size)
+
+
+def test_jitter():
+ rnds = [pool.ConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
+ assert 27 <= min(rnds) <= 28
+ assert 35 < max(rnds) < 36
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_max_lifetime(dsn):
+ with pool.ConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
+ sleep(0.1)
+ pids = []
+ for i in range(5):
+ with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ sleep(0.2)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_check(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ with pool.ConnectionPool(dsn, min_size=4) as p:
+ p.wait(1.0)
+ with p.connection() as conn:
+ pid = conn.info.backend_pid
+
+ p.wait(1.0)
+ pids = set(conn.info.backend_pid for conn in p._pool)
+ assert pid in pids
+ conn.close()
+
+ assert len(caplog.records) == 0
+ p.check()
+ assert len(caplog.records) == 1
+ p.wait(1.0)
+ pids2 = set(conn.info.backend_pid for conn in p._pool)
+ assert len(pids & pids2) == 3
+ assert pid not in pids2
+
+
+def test_check_idle(dsn):
+ with pool.ConnectionPool(dsn, min_size=2) as p:
+ p.wait(1.0)
+ p.check()
+ with p.connection() as conn:
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+
+ with pool.ConnectionPool(dsn, min_size=2, max_size=4) as p:
+ p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 2
+ assert stats["pool_available"] == 2
+ assert stats["requests_waiting"] == 0
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ p.wait(2.0)
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn):
+ def worker(n):
+ try:
+ with p.connection(timeout=0.3) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ except pool.PoolTimeout:
+ pass
+
+ with pool.ConnectionPool(dsn, min_size=3) as p:
+ p.wait(2.0)
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ with p.connection() as conn:
+ conn.close()
+ p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ with pool.ConnectionPool(proxy.client_dsn, min_size=3) as p:
+ p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 3
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 600 <= stats["connections_ms"] < 1200
+
+ proxy.stop()
+ p.check()
+ sleep(0.1)
+ stats = p.get_stats()
+ assert stats["connections_num"] > 3
+ assert stats["connections_errors"] > 0
+ assert stats["connections_lost"] == 3
+
+
+@pytest.mark.slow
+def test_spike(dsn, monkeypatch):
+ # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/
+ # documents/Welcome-To-The-Jungle.md
+ delay_connection(monkeypatch, 0.15)
+
+ def worker():
+ with p.connection():
+ sleep(0.002)
+
+ with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p:
+ p.wait()
+
+ ts = [Thread(target=worker) for i in range(50)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ p.wait()
+
+ assert len(p._pool) < 7
+
+
+def test_debug_deadlock(dsn):
+ # https://github.com/psycopg/psycopg/issues/230
+ logger = logging.getLogger("psycopg")
+ handler = logging.StreamHandler()
+ old_level = logger.level
+ logger.setLevel(logging.DEBUG)
+ handler.setLevel(logging.DEBUG)
+ logger.addHandler(handler)
+ try:
+ with pool.ConnectionPool(dsn, min_size=4, open=True) as p:
+ try:
+ p.wait(timeout=2)
+ finally:
+ print(p.get_stats())
+ finally:
+ logger.removeHandler(handler)
+ logger.setLevel(old_level)
+
+
+def delay_connection(monkeypatch, sec):
+ """
+ Return a _connect_gen function delayed by the amount of seconds
+ """
+
+ def connect_delay(*args, **kwargs):
+ t0 = time()
+ rv = connect_orig(*args, **kwargs)
+ t1 = time()
+ sleep(max(0, sec - (t1 - t0)))
+ return rv
+
+ connect_orig = psycopg.Connection.connect
+ monkeypatch.setattr(psycopg.Connection, "connect", connect_delay)
+
+
+def ensure_waiting(p, num=1):
+ """
+ Wait until there are at least *num* clients waiting in the queue.
+ """
+ while len(p._waiting) < num:
+ sleep(0)
diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py
new file mode 100644
index 0000000..286a775
--- /dev/null
+++ b/tests/pool/test_pool_async.py
@@ -0,0 +1,1198 @@
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task, Counter
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.asyncio]
+
+
+async def test_defaults(dsn):
+ async with pool.AsyncConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 4
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
+async def test_min_size_max_size(dsn, min_size, max_size):
+ async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
+ assert p.min_size == min_size
+ assert p.max_size == max_size if max_size is not None else min_size
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
+async def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ pool.AsyncConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+ class MyConn(psycopg.AsyncConnection[Any]):
+ pass
+
+ async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
+ async with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+ async with pool.AsyncConnectionPool(
+ dsn, kwargs={"autocommit": True}, min_size=1
+ ) as p:
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_its_really_a_pool(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ async with p.connection() as conn:
+ assert conn.info.backend_pid in (pid1, pid2)
+
+
+async def test_context(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_connection_not_lost(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ with pytest.raises(ZeroDivisionError):
+ async with p.connection() as conn:
+ pid = conn.info.backend_pid
+ 1 / 0
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_concurrent_filling(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+
+ async def add_time(self, conn):
+ times.append(time() - t0)
+ await add_orig(self, conn)
+
+ add_orig = pool.AsyncConnectionPool._add_to_pool
+ monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time)
+
+ times: List[float] = []
+ t0 = time()
+
+ async with pool.AsyncConnectionPool(dsn, min_size=5, num_workers=2) as p:
+ await p.wait(1.0)
+ want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+ assert len(times) == len(want_times)
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.wait(0.3)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.wait(0.5)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=2) as p:
+ await p.wait(0.3)
+ await p.wait(0.0001) # idempotent
+
+
+async def test_wait_closed(dsn):
+ async with pool.AsyncConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(pool.PoolClosed):
+ await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=1, num_workers=1
+ ) as p:
+ await p.wait(0.2)
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=1, num_workers=1
+ ) as p:
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+ inits = 0
+
+ async def configure(conn):
+ nonlocal inits
+ inits += 1
+ async with conn.transaction():
+ await conn.execute("set default_transaction_read_only to on")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ await p.wait(timeout=1.0)
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+ await conn.close()
+
+ async with p.connection() as conn:
+ assert inits == 2
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ await conn.execute("select 1")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
+ with pytest.raises(pool.PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+
+ await p.wait()
+ assert resets == 1
+
+ async with p.connection() as conn:
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+
+ await p.wait()
+ assert resets == 2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.info.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.info.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue(dsn):
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ await p.wait()
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+ async def worker(t, ev=None):
+ try:
+ async with p.connection():
+ if ev:
+ ev.set()
+ await asyncio.sleep(t)
+ except pool.TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p:
+ await p.wait()
+ ev = asyncio.Event()
+ create_task(worker(0.3, ev))
+ await ev.wait()
+
+ ts = [create_task(worker(0.1)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], pool.TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout(dsn):
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ results: List[int] = []
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_queue_timeout_override(dsn):
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ pid = conn.info.backend_pid
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_broken_reconnect(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ async with p.connection() as conn:
+ pid1 = conn.info.backend_pid
+ await conn.close()
+
+ async with p.connection() as conn2:
+ pid2 = conn2.info.backend_pid
+
+ assert pid1 != pid2
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+ cur = await conn2.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid == pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+async def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+ pid = conn.info.backend_pid
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+ assert conn.info.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await p.getconn()
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.info.backend_pid
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.info.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.info.backend_pid != pid
+ assert conn2.info.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ assert p._sched_runner and not p._sched_runner.done()
+ assert p._workers
+ workers = p._workers[:]
+ for t in workers:
+ assert not t.done()
+
+ await p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert t.done()
+
+
+async def test_putconn_no_pool(aconn_cls, dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
+ conn = await aconn_cls.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+
+ await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p1:
+ async with pool.AsyncConnectionPool(dsn, min_size=1) as p2:
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(pool.PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+
+ async with p.connection() as conn:
+ pass
+ assert not conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+async def test_closed_queue(dsn):
+ async def w1():
+ async with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ await e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ async def w2():
+ try:
+ async with p.connection():
+ pass # unexpected
+ except pool.PoolClosed:
+ success.append("w2")
+
+ e1 = asyncio.Event()
+ e2 = asyncio.Event()
+
+ p = pool.AsyncConnectionPool(dsn, min_size=1)
+ await p.wait()
+ success: List[str] = []
+
+ t1 = create_task(w1())
+ # Wait until w1 has received a connection
+ await e1.wait()
+
+ t2 = create_task(w2())
+ # Wait until w2 is in the queue
+ await ensure_waiting(p)
+ await p.close()
+
+ # Wait for the workers to finish
+ e2.set()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+ p = pool.AsyncConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(pool.PoolClosed):
+ await p.getconn()
+
+ with pytest.raises(pool.PoolClosed, match="is not open yet"):
+ async with p.connection():
+ pass
+
+ await p.open()
+ try:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+ with pytest.raises(pool.PoolClosed, match="is already closed"):
+ await p.getconn()
+
+
+async def test_open_context(dsn):
+ p = pool.AsyncConnectionPool(dsn, open=False)
+ assert p.closed
+
+ async with p:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+async def test_open_no_op(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ try:
+ assert not p.closed
+ await p.open()
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_open_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ await p.open(wait=True, timeout=0.3)
+ finally:
+ await p.close()
+
+ p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
+ try:
+ await p.open(wait=True, timeout=0.5)
+ finally:
+ await p.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_open_as_wait(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.open(wait=True, timeout=0.3)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
+ await p.open(wait=True, timeout=0.5)
+
+
+async def test_reopen(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ await p.close()
+ assert p._sched_runner is None
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ await p.open()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize(
+ "min_size, want_times",
+ [
+ (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]),
+ (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]),
+ ],
+)
+async def test_grow(dsn, monkeypatch, min_size, want_times):
+ delay_connection(monkeypatch, 0.1)
+
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1 from pg_sleep(0.25)")
+ t1 = time()
+ results.append((n, t1 - t0))
+
+ async with pool.AsyncConnectionPool(
+ dsn, min_size=min_size, max_size=4, num_workers=3
+ ) as p:
+ await p.wait(1.0)
+ ts = []
+ results: List[Tuple[int, float]] = []
+
+ ts = [create_task(worker(i)) for i in range(len(want_times))]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_shrink(dsn, monkeypatch):
+
+ from psycopg_pool.pool_async import ShrinkPool
+
+ results: List[Tuple[int, int]] = []
+
+ async def run_hacked(self, pool):
+ n0 = pool._nconns
+ await orig_run(self, pool)
+ n1 = pool._nconns
+ results.append((n0, n1))
+
+ orig_run = ShrinkPool._run
+ monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
+
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.1)")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
+ await p.wait(5.0)
+ assert p.max_idle == 0.2
+
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(1)
+
+ assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+async def test_reconnect(proxy, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+ caplog.clear()
+ proxy.start()
+ async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=1) as p:
+ await p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await asyncio.sleep(1.0)
+ proxy.start()
+ await p.wait()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ assert "BAD" in caplog.messages[0]
+ times = [rec.created for rec in caplog.records]
+ assert times[1] - times[0] < 0.05
+ deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+ assert len(deltas) == 3
+ want = 0.1
+ for delta in deltas:
+ assert delta == pytest.approx(want, 0.05), deltas
+ want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_reconnect_failure(proxy):
+ proxy.start()
+
+ t1 = None
+
+ def failed(pool):
+ assert pool.name == "this-one"
+ nonlocal t1
+ t1 = time()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn,
+ name="this-one",
+ min_size=1,
+ reconnect_timeout=1.0,
+ reconnect_failed=failed,
+ ) as p:
+ await p.wait(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ t0 = time()
+ await asyncio.sleep(1.5)
+ assert t1
+ assert t1 - t0 == pytest.approx(1.0, 0.1)
+ assert p._nconns == 0
+
+ proxy.start()
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ t1 = time()
+ assert t1 - t0 < 0.2
+
+
+@pytest.mark.slow
+async def test_reconnect_after_grow_failed(proxy):
+ # Retry reconnection after a failed connection attempt has put the pool
+ # in grow mode. See issue #370.
+ proxy.stop()
+
+ ev = asyncio.Event()
+
+ def failed(pool):
+ ev.set()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ await asyncio.wait_for(ev.wait(), 2.0)
+
+ with pytest.raises(pool.PoolTimeout):
+ async with p.connection(timeout=0.5) as conn:
+ pass
+
+ ev.clear()
+ await asyncio.wait_for(ev.wait(), 2.0)
+
+ proxy.start()
+
+ async with p.connection(timeout=2) as conn:
+ await conn.execute("select 1")
+
+ await p.wait(timeout=3.0)
+ assert len(p._pool) == p.min_size == 4
+
+
+@pytest.mark.slow
+async def test_refill_on_check(proxy):
+ proxy.start()
+ ev = asyncio.Event()
+
+ def failed(pool):
+ ev.set()
+
+ async with pool.AsyncConnectionPool(
+ proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+ ) as p:
+ # The pool is full
+ await p.wait(timeout=2)
+
+ # Break all the connection
+ proxy.stop()
+
+ # Checking the pool will empty it
+ await p.check()
+ await asyncio.wait_for(ev.wait(), 2.0)
+ assert len(p._pool) == 0
+
+ # Allow to connect again
+ proxy.start()
+
+ # Make sure that check has refilled the pool
+ await p.check()
+ await p.wait(timeout=2)
+ assert len(p._pool) == 4
+
+
+@pytest.mark.slow
+async def test_uniform_use(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+ counts = Counter[int]()
+ for i in range(8):
+ async with p.connection() as conn:
+ await asyncio.sleep(0.1)
+ counts[id(conn)] += 1
+
+ assert len(counts) == 4
+ assert set(counts.values()) == set([2])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_resize(dsn):
+ async def sampler():
+ await asyncio.sleep(0.05) # ensure sampling happens after shrink check
+ while True:
+ await asyncio.sleep(0.2)
+ if p.closed:
+ break
+ size.append(len(p._pool))
+
+ async def client(t):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(%s)", [t])
+
+ size: List[int] = []
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
+ s = create_task(sampler())
+
+ await asyncio.sleep(0.3)
+
+ c = create_task(client(0.4))
+
+ await asyncio.sleep(0.2)
+ await p.resize(4)
+ assert p.min_size == 4
+ assert p.max_size == 4
+
+ await asyncio.sleep(0.4)
+ await p.resize(2)
+ assert p.min_size == 2
+ assert p.max_size == 2
+
+ await asyncio.sleep(0.6)
+
+ await asyncio.gather(s, c)
+ assert size == [2, 1, 3, 4, 3, 2, 2]
+
+
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)])
+async def test_bad_resize(dsn, min_size, max_size):
+ async with pool.AsyncConnectionPool() as p:
+ with pytest.raises(ValueError):
+ await p.resize(min_size=min_size, max_size=max_size)
+
+
+async def test_jitter():
+ rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
+ assert 27 <= min(rnds) <= 28
+ assert 35 < max(rnds) < 36
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+async def test_max_lifetime(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
+ await asyncio.sleep(0.1)
+ pids = []
+ for i in range(5):
+ async with p.connection() as conn:
+ pids.append(conn.info.backend_pid)
+ await asyncio.sleep(0.2)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+@pytest.mark.crdb_skip("backend pid")
+async def test_check(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ async with pool.AsyncConnectionPool(dsn, min_size=4) as p:
+ await p.wait(1.0)
+ async with p.connection() as conn:
+ pid = conn.info.backend_pid
+
+ await p.wait(1.0)
+ pids = set(conn.info.backend_pid for conn in p._pool)
+ assert pid in pids
+ await conn.close()
+
+ assert len(caplog.records) == 0
+ await p.check()
+ assert len(caplog.records) == 1
+ await p.wait(1.0)
+ pids2 = set(conn.info.backend_pid for conn in p._pool)
+ assert len(pids & pids2) == 3
+ assert pid not in pids2
+
+
+async def test_check_idle(dsn):
+ async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
+ await p.wait(1.0)
+ await p.check()
+ async with p.connection() as conn:
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4) as p:
+ await p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 2
+ assert stats["pool_available"] == 2
+ assert stats["requests_waiting"] == 0
+
+ ts = [create_task(worker(i)) for i in range(3)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ await p.wait(2.0)
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 2
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn):
+ async def worker(n):
+ try:
+ async with p.connection(timeout=0.3) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ except pool.PoolTimeout:
+ pass
+
+ async with pool.AsyncConnectionPool(dsn, min_size=3) as p:
+ await p.wait(2.0)
+
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.gather(*ts)
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ async with p.connection() as conn:
+ await conn.close()
+ await p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ async with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=3) as p:
+ await p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 3
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 580 <= stats["connections_ms"] < 1200
+
+ proxy.stop()
+ await p.check()
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ assert stats["connections_num"] > 3
+ assert stats["connections_errors"] > 0
+ assert stats["connections_lost"] == 3
+
+
+@pytest.mark.slow
+async def test_spike(dsn, monkeypatch):
+ # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/
+ # documents/Welcome-To-The-Jungle.md
+ delay_connection(monkeypatch, 0.15)
+
+ async def worker():
+ async with p.connection():
+ await asyncio.sleep(0.002)
+
+ async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p:
+ await p.wait()
+
+ ts = [create_task(worker()) for i in range(50)]
+ await asyncio.gather(*ts)
+ await p.wait()
+
+ assert len(p._pool) < 7
+
+
+async def test_debug_deadlock(dsn):
+ # https://github.com/psycopg/psycopg/issues/230
+ logger = logging.getLogger("psycopg")
+ handler = logging.StreamHandler()
+ old_level = logger.level
+ logger.setLevel(logging.DEBUG)
+ handler.setLevel(logging.DEBUG)
+ logger.addHandler(handler)
+ try:
+ async with pool.AsyncConnectionPool(dsn, min_size=4, open=True) as p:
+ await p.wait(timeout=2)
+ finally:
+ logger.removeHandler(handler)
+ logger.setLevel(old_level)
+
+
+def delay_connection(monkeypatch, sec):
+ """
+ Return a _connect_gen function delayed by the amount of seconds
+ """
+
+ async def connect_delay(*args, **kwargs):
+ t0 = time()
+ rv = await connect_orig(*args, **kwargs)
+ t1 = time()
+ await asyncio.sleep(max(0, sec - (t1 - t0)))
+ return rv
+
+ connect_orig = psycopg.AsyncConnection.connect
+ monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay)
+
+
+async def ensure_waiting(p, num=1):
+ while len(p._waiting) < num:
+ await asyncio.sleep(0)
diff --git a/tests/pool/test_pool_async_noasyncio.py b/tests/pool/test_pool_async_noasyncio.py
new file mode 100644
index 0000000..f6e34e4
--- /dev/null
+++ b/tests/pool/test_pool_async_noasyncio.py
@@ -0,0 +1,78 @@
+# These tests relate to AsyncConnectionPool, but are not marked asyncio
+# because they rely on the pool initialization outside the asyncio loop.
+
+import asyncio
+
+import pytest
+
+from ..utils import gc_collect
+
+try:
+ import psycopg_pool as pool
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+
+@pytest.mark.slow
+def test_reconnect_after_max_lifetime(dsn, asyncio_run):
+ # See issue #219, pool created before the loop.
+ p = pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2, open=False)
+
+ async def test():
+ try:
+ await p.open()
+ ns = []
+ for i in range(5):
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ ns.append(await cur.fetchone())
+ await asyncio.sleep(0.2)
+ assert len(ns) == 5
+ finally:
+ await p.close()
+
+ asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+@pytest.mark.slow
+def test_working_created_before_loop(dsn, asyncio_run):
+ p = pool.AsyncNullConnectionPool(dsn, open=False)
+
+ async def test():
+ try:
+ await p.open()
+ ns = []
+ for i in range(5):
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ ns.append(await cur.fetchone())
+ await asyncio.sleep(0.2)
+ assert len(ns) == 5
+ finally:
+ await p.close()
+
+ asyncio_run(asyncio.wait_for(test(), timeout=2.0))
+
+
+def test_cant_create_open_outside_loop(dsn):
+ with pytest.raises(RuntimeError):
+ pool.AsyncConnectionPool(dsn, open=True)
+
+
+@pytest.fixture
+def asyncio_run(recwarn):
+ """Fixture reuturning asyncio.run, but managing resources at exit.
+
+ In certain runs, fd objects are leaked and the error will only be caught
+ downstream, by some innocent test calling gc_collect().
+ """
+ recwarn.clear()
+ try:
+ yield asyncio.run
+ finally:
+ gc_collect()
+ if recwarn:
+ warn = recwarn.pop(ResourceWarning)
+ assert "unclosed event loop" in str(warn.message)
+ assert not recwarn
diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py
new file mode 100644
index 0000000..b3d2572
--- /dev/null
+++ b/tests/pool/test_sched.py
@@ -0,0 +1,154 @@
+import logging
+from time import time, sleep
+from functools import partial
+from threading import Thread
+
+import pytest
+
+try:
+ from psycopg_pool.sched import Scheduler
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.timing]
+
+
+@pytest.mark.slow
+def test_sched():
+ s = Scheduler()
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+ s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+def test_sched_thread():
+ s = Scheduler()
+ t = Thread(target=s.run, daemon=True)
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.2)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.2)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.2)
+
+
+@pytest.mark.slow
+def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ s = Scheduler()
+ t = Thread(target=s.run, daemon=True)
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ def error():
+ 1 / 0
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, None)
+ s.enter(0.3, partial(worker, 2))
+ s.enter(0.2, error)
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_empty_queue_timeout():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = Thread(target=s.run)
+ t.start()
+ sleep(0.5)
+ s.enter(0.5, None)
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+def test_first_task_rescheduling():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ s.enter(0.4, lambda: None)
+ t = Thread(target=s.run)
+ t.start()
+ s.enter(0.6, None) # this task doesn't trigger a reschedule
+ sleep(0.1)
+ s.enter(0.1, lambda: None) # this triggers a reschedule
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.2), times
diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py
new file mode 100644
index 0000000..492d620
--- /dev/null
+++ b/tests/pool/test_sched_async.py
@@ -0,0 +1,159 @@
+import asyncio
+import logging
+from time import time
+from functools import partial
+
+import pytest
+
+from psycopg._compat import create_task
+
+try:
+ from psycopg_pool.sched import AsyncScheduler
+except ImportError:
+ # Tests should have been skipped if the package is not available
+ pass
+
+pytestmark = [pytest.mark.asyncio, pytest.mark.timing]
+
+
+@pytest.mark.slow
+async def test_sched():
+ s = AsyncScheduler()
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+ await s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_task():
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.2)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.2)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.2)
+
+
+@pytest.mark.slow
+async def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ async def error():
+ 1 / 0
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, None)
+ await s.enter(0.3, partial(worker, 2))
+ await s.enter(0.2, error)
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_empty_queue_timeout():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = create_task(s.run())
+ await asyncio.sleep(0.5)
+ await s.enter(0.5, None)
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+async def test_first_task_rescheduling():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ setattr(s._event, "wait", wait_logging)
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ async def noop():
+ pass
+
+ await s.enter(0.4, noop)
+ t = create_task(s.run())
+ await s.enter(0.6, None) # this task doesn't trigger a reschedule
+ await asyncio.sleep(0.1)
+ await s.enter(0.1, noop) # this triggers a reschedule
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.2), times
diff --git a/tests/pq/__init__.py b/tests/pq/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/pq/__init__.py
diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py
new file mode 100644
index 0000000..2c3de98
--- /dev/null
+++ b/tests/pq/test_async.py
@@ -0,0 +1,210 @@
+from select import select
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg.generators import execute
+
+
+def execute_wait(pgconn):
+ return psycopg.waiting.wait(execute(pgconn), pgconn.socket)
+
+
+def test_send_query(pgconn):
+ # This test shows how to process an async query in all its glory
+ pgconn.nonblocking = 1
+
+ # Long query to make sure we have to wait on send
+ pgconn.send_query(
+ b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+ % (b"x" * 1_000_000)
+ )
+
+ # send loop
+ waited_on_send = 0
+ while True:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ waited_on_send += 1
+
+ rl, wl, xl = select([pgconn.socket], [pgconn.socket], [])
+ assert not (rl and wl)
+ if wl:
+ continue # call flush again()
+ if rl:
+ pgconn.consume_input()
+ continue
+
+ # TODO: this check is not reliable, it fails on travis sometimes
+ # assert waited_on_send
+
+ # read loop
+ results = []
+ while True:
+ pgconn.consume_input()
+ if pgconn.is_busy():
+ select([pgconn.socket], [], [])
+ continue
+ res = pgconn.get_result()
+ if res is None:
+ break
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ results.append(res)
+
+ assert len(results) == 2
+ assert results[0].nfields == 1
+ assert results[0].fname(0) == b"f"
+ assert results[0].get_value(0, 0) == b"x"
+ assert results[1].nfields == 1
+ assert results[1].fname(0) == b"foo"
+ assert results[1].get_value(0, 0) == b"1"
+
+
+def test_send_query_compact_test(pgconn):
+ # Like the above test but use psycopg facilities for compactness
+ pgconn.send_query(
+ b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;"
+ % (b"x" * 1_000_000)
+ )
+ results = execute_wait(pgconn)
+
+ assert len(results) == 2
+ assert results[0].nfields == 1
+ assert results[0].fname(0) == b"f"
+ assert results[0].get_value(0, 0) == b"x"
+ assert results[1].nfields == 1
+ assert results[1].fname(0) == b"foo"
+ assert results[1].get_value(0, 0) == b"1"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query(b"select 1")
+
+
+def test_single_row_mode(pgconn):
+ pgconn.send_query(b"select generate_series(1,2)")
+ pgconn.set_single_row_mode()
+
+ results = execute_wait(pgconn)
+ assert len(results) == 3
+
+ res = results[0]
+ assert res.status == pq.ExecStatus.SINGLE_TUPLE
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"1"
+
+ res = results[1]
+ assert res.status == pq.ExecStatus.SINGLE_TUPLE
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"2"
+
+ res = results[2]
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.ntuples == 0
+
+
+def test_send_query_params(pgconn):
+ pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query_params(b"select $1", [b"1"])
+
+
+def test_send_prepare(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+ (res,) = execute_wait(pgconn)
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+
+
+def test_send_prepare_types(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"prep", [b"3", b"5"])
+ (res,) = execute_wait(pgconn)
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_send_prepared_binary_in(pgconn):
+ val = b"foo\00bar"
+ pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1])
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_send_prepared_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ pgconn.send_prepare(b"", b"select $1::bytea")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt)
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+def test_send_describe_prepared(pgconn):
+ pgconn.send_prepare(b"prep", b"select $1::int8 + $2::int8 as fld")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_describe_prepared(b"prep")
+ (res,) = execute_wait(pgconn)
+ assert res.nfields == 1
+ assert res.ntuples == 0
+ assert res.fname(0) == b"fld"
+ assert res.ftype(0) == 20
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_describe_prepared(b"prep")
+
+
+@pytest.mark.crdb_skip("server-side cursor")
+def test_send_describe_portal(pgconn):
+ res = pgconn.exec_(
+ b"""
+ begin;
+ declare cur cursor for select * from generate_series(1,10) foo;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ pgconn.send_describe_portal(b"cur")
+ (res,) = execute_wait(pgconn)
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ assert res.nfields == 1
+ assert res.fname(0) == b"foo"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.send_describe_portal(b"cur")
diff --git a/tests/pq/test_conninfo.py b/tests/pq/test_conninfo.py
new file mode 100644
index 0000000..64d8b8f
--- /dev/null
+++ b/tests/pq/test_conninfo.py
@@ -0,0 +1,48 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_defaults(monkeypatch):
+ monkeypatch.setenv("PGPORT", "15432")
+ defs = pq.Conninfo.get_defaults()
+ assert len(defs) > 20
+ port = [d for d in defs if d.keyword == b"port"][0]
+ assert port.envvar == b"PGPORT"
+ assert port.compiled == b"5432"
+ assert port.val == b"15432"
+ assert port.label == b"Database-Port"
+ assert port.dispchar == b""
+ assert port.dispsize == 6
+
+
+@pytest.mark.libpq(">= 10")
+def test_conninfo_parse():
+ infos = pq.Conninfo.parse(
+ b"postgresql://host1:123,host2:456/somedb"
+ b"?target_session_attrs=any&application_name=myapp"
+ )
+ info = {i.keyword: i.val for i in infos if i.val is not None}
+ assert info[b"host"] == b"host1,host2"
+ assert info[b"port"] == b"123,456"
+ assert info[b"dbname"] == b"somedb"
+ assert info[b"application_name"] == b"myapp"
+
+
+@pytest.mark.libpq("< 10")
+def test_conninfo_parse_96():
+ conninfo = pq.Conninfo.parse(
+ b"postgresql://other@localhost/otherdb"
+ b"?connect_timeout=10&application_name=myapp"
+ )
+ info = {i.keyword: i.val for i in conninfo if i.val is not None}
+ assert info[b"host"] == b"localhost"
+ assert info[b"dbname"] == b"otherdb"
+ assert info[b"application_name"] == b"myapp"
+
+
+def test_conninfo_parse_bad():
+ with pytest.raises(psycopg.OperationalError) as e:
+ pq.Conninfo.parse(b"bad_conninfo=")
+ assert "bad_conninfo" in str(e.value)
diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py
new file mode 100644
index 0000000..383d272
--- /dev/null
+++ b/tests/pq/test_copy.py
@@ -0,0 +1,174 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+pytestmark = pytest.mark.crdb_skip("copy")
+
+sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+
+sample_tabledef = "col1 int primary key, col2 int, data text"
+
+sample_text = b"""\
+10\t20\thello
+40\t\\N\tworld
+"""
+
+sample_binary_value = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n")
+]
+
+sample_binary = b"".join(sample_binary_rows)
+
+
+def test_put_data_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_data(b"wat")
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_data(b"wat")
+
+
+def test_put_end_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_end()
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.put_copy_end()
+
+
+def test_copy_out(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_(
+ b"select min(col1), max(col1), count(*), max(length(data)) from copy_in"
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+ assert res.get_value(0, 1) == b"199"
+ assert res.get_value(0, 2) == b"200"
+ assert res.get_value(0, 3) == b"199"
+
+
+def test_copy_out_err(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\thardly a number\tnope
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end()
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"hardly a number" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def test_copy_out_error_end(pgconn):
+ ensure_table(pgconn, sample_tabledef)
+ res = pgconn.exec_(b"copy copy_in from stdin")
+ assert res.status == pq.ExecStatus.COPY_IN
+
+ for i in range(10):
+ data = []
+ for j in range(20):
+ data.append(
+ f"""\
+{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)}
+"""
+ )
+ rv = pgconn.put_copy_data("".join(data).encode("ascii"))
+ assert rv > 0
+
+ rv = pgconn.put_copy_end(b"nuttengoggenio")
+ assert rv > 0
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"nuttengoggenio" in res.error_message
+
+ res = pgconn.exec_(b"select count(*) from copy_in")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"0"
+
+
+def test_get_data_no_copy(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_copy_data(0)
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_copy_data(0)
+
+
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_out_read(pgconn, format):
+ stmt = f"copy ({sample_values}) to stdout (format {format.name})"
+ res = pgconn.exec_(stmt.encode("ascii"))
+ assert res.status == pq.ExecStatus.COPY_OUT
+ assert res.binary_tuples == format
+
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ for row in want:
+ nbytes, data = pgconn.get_copy_data(0)
+ assert nbytes == len(data)
+ assert data == row
+
+ assert pgconn.get_copy_data(0) == (-1, b"")
+
+ res = pgconn.get_result()
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
+def ensure_table(pgconn, tabledef, name="copy_in"):
+ pgconn.exec_(f"drop table if exists {name}".encode("ascii"))
+ pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii"))
diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py
new file mode 100644
index 0000000..ad88d8a
--- /dev/null
+++ b/tests/pq/test_escaping.py
@@ -0,0 +1,188 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+from ..fix_crdb import crdb_scs_off
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b"''"),
+ (b"hello", b"'hello'"),
+ (b"foo'bar", b"'foo''bar'"),
+ (b"foo\\bar", b" E'foo\\\\bar'"),
+ ],
+)
+def test_escape_literal(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_literal(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_literal_1char(pgconn, scs):
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ esc = pq.Escaping(pgconn)
+ special = {b"'": b"''''", b"\\": b" E'\\\\'"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_literal(data)
+ exp = special.get(data) or b"'%s'" % data
+ assert rv == exp
+
+
+def test_escape_literal_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_literal(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_literal(b"hi")
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b'""'),
+ (b"hello", b'"hello"'),
+ (b'foo"bar', b'"foo""bar"'),
+ (b"foo\\bar", b'"foo\\bar"'),
+ ],
+)
+def test_escape_identifier(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_identifier(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_identifier_1char(pgconn, scs):
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ esc = pq.Escaping(pgconn)
+ special = {b'"': b'""""', b"\\": b'"\\"'}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_identifier(data)
+ exp = special.get(data) or b'"%s"' % data
+ assert rv == exp
+
+
+def test_escape_identifier_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_identifier(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_identifier(b"hi")
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b""),
+ (b"hello", b"hello"),
+ (b"foo'bar", b"foo''bar"),
+ (b"foo\\bar", b"foo\\bar"),
+ ],
+)
+def test_escape_string(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_string(data)
+ assert out == want
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_escape_string_1char(pgconn, scs):
+ esc = pq.Escaping(pgconn)
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_string(data)
+ exp = special.get(data) or b"%s" % data
+ assert rv == exp
+
+
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b""),
+ (b"hello", b"hello"),
+ (b"foo'bar", b"foo''bar"),
+ # This libpq function behaves unpredictably when not passed a conn
+ (b"foo\\bar", (b"foo\\\\bar", b"foo\\bar")),
+ ],
+)
+def test_escape_string_noconn(data, want):
+ esc = pq.Escaping()
+ out = esc.escape_string(data)
+ if isinstance(want, bytes):
+ assert out == want
+ else:
+ assert out in want
+
+
+def test_escape_string_badconn(pgconn):
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_string(b"hi")
+
+
+def test_escape_string_badenc(pgconn):
+ res = pgconn.exec_(b"set client_encoding to 'UTF8'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ data = "\u20ac".encode()[:-1]
+ esc = pq.Escaping(pgconn)
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_string(data)
+
+
+@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"])
+def test_escape_bytea(pgconn, data):
+ exp = rb"\x" + b"".join(b"%02x" % c for c in data)
+ esc = pq.Escaping(pgconn)
+ rv = esc.escape_bytea(data)
+ assert rv == exp
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.escape_bytea(data)
+
+
+def test_escape_noconn(pgconn):
+ data = bytes(range(256))
+ esc = pq.Escaping()
+ escdata = esc.escape_bytea(data)
+ res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == data
+
+
+def test_escape_1char(pgconn):
+ esc = pq.Escaping(pgconn)
+ for c in range(256):
+ rv = esc.escape_bytea(bytes([c]))
+ exp = rb"\x%02x" % c
+ assert rv == exp
+
+
+@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"])
+def test_unescape_bytea(pgconn, data):
+ enc = rb"\x" + b"".join(b"%02x" % c for c in data)
+ esc = pq.Escaping(pgconn)
+ rv = esc.unescape_bytea(enc)
+ assert rv == data
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ esc.unescape_bytea(data)
diff --git a/tests/pq/test_exec.py b/tests/pq/test_exec.py
new file mode 100644
index 0000000..86c30c0
--- /dev/null
+++ b/tests/pq/test_exec.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_exec_none(pgconn):
+ with pytest.raises(TypeError):
+ pgconn.exec_(None)
+
+
+def test_exec(pgconn):
+ res = pgconn.exec_(b"select 'hel' || 'lo'")
+ assert res.get_value(0, 0) == b"hello"
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_(b"select 'hello'")
+
+
+def test_exec_params(pgconn):
+ res = pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"])
+
+
+def test_exec_params_empty(pgconn):
+ res = pgconn.exec_params(b"select 8::int", [])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_exec_params_types(pgconn):
+ res = pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700, 23])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"8"
+ assert res.ftype(0) == 1700
+ assert res.get_value(0, 1) == b"8"
+ assert res.ftype(1) == 23
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700])
+
+
+def test_exec_params_nulls(pgconn):
+ res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"hi"
+ assert res.get_value(0, 1) == b""
+ assert res.get_value(0, 2) is None
+
+
+def test_exec_params_binary_in(pgconn):
+ val = b"foo\00bar"
+ res = pgconn.exec_params(
+ b"select length($1::bytea), length($2::bytea)",
+ [val, val],
+ param_formats=[0, 1],
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_exec_params_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ res = pgconn.exec_params(
+ b"select $1::bytea", [val], param_formats=[1], result_format=fmt
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+def test_prepare(pgconn):
+ res = pgconn.prepare(b"prep", b"select $1::int + $2::int")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"prep", [b"3", b"5"])
+ assert res.get_value(0, 0) == b"8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.prepare(b"prep", b"select $1::int + $2::int")
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.exec_prepared(b"prep", [b"3", b"5"])
+
+
+def test_prepare_types(pgconn):
+ res = pgconn.prepare(b"prep", b"select $1 + $2", [23, 23])
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"prep", [b"3", b"5"])
+ assert res.get_value(0, 0) == b"8"
+
+
+def test_exec_prepared_binary_in(pgconn):
+ val = b"foo\00bar"
+ res = pgconn.prepare(b"", b"select length($1::bytea), length($2::bytea)")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"", [val, val], param_formats=[0, 1])
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == b"3"
+ assert res.get_value(0, 1) == b"7"
+
+ with pytest.raises(ValueError):
+ pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
+
+
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
+def test_exec_prepared_binary_out(pgconn, fmt, out):
+ val = b"foo\00bar"
+ res = pgconn.prepare(b"", b"select $1::bytea")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt)
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ assert res.get_value(0, 0) == out
+
+
+@pytest.mark.crdb_skip("server-side cursor")
+def test_describe_portal(pgconn):
+ res = pgconn.exec_(
+ b"""
+ begin;
+ declare cur cursor for select * from generate_series(1,10) foo;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.describe_portal(b"cur")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ assert res.nfields == 1
+ assert res.fname(0) == b"foo"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.describe_portal(b"cur")
diff --git a/tests/pq/test_misc.py b/tests/pq/test_misc.py
new file mode 100644
index 0000000..599758f
--- /dev/null
+++ b/tests/pq/test_misc.py
@@ -0,0 +1,83 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+def test_error_message(pgconn):
+ res = pgconn.exec_(b"wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ msg = pq.error_message(pgconn)
+ assert "wat" in msg
+ assert msg == pq.error_message(res)
+ primary = res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+ assert primary.decode("ascii") in msg
+
+ with pytest.raises(TypeError):
+ pq.error_message(None) # type: ignore[arg-type]
+
+ res.clear()
+ assert pq.error_message(res) == "no details available"
+ pgconn.finish()
+ assert "NULL" in pq.error_message(pgconn)
+
+
+@pytest.mark.crdb_skip("encoding")
+def test_error_message_encoding(pgconn):
+ res = pgconn.exec_(b"set client_encoding to latin9")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+
+ res = pgconn.exec_('select 1 from "foo\u20acbar"'.encode("latin9"))
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+
+ msg = pq.error_message(pgconn)
+ assert "foo\u20acbar" in msg
+
+ msg = pq.error_message(res)
+ assert "foo\ufffdbar" in msg
+
+ msg = pq.error_message(res, encoding="latin9")
+ assert "foo\u20acbar" in msg
+
+ msg = pq.error_message(res, encoding="ascii")
+ assert "foo\ufffdbar" in msg
+
+
+def test_make_empty_result(pgconn):
+ pgconn.exec_(b"wat")
+ res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert b"wat" in res.error_message
+
+ pgconn.finish()
+ res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert res.error_message == b""
+
+
+def test_result_set_attrs(pgconn):
+ res = pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
+ assert res.status == pq.ExecStatus.COPY_OUT
+
+ attrs = [
+ pq.PGresAttDesc(b"an_int", 0, 0, 0, 23, 0, 0),
+ pq.PGresAttDesc(b"a_num", 0, 0, 0, 1700, 0, 0),
+ pq.PGresAttDesc(b"a_bin_text", 0, 0, 1, 25, 0, 0),
+ ]
+ res.set_attributes(attrs)
+ assert res.nfields == 3
+
+ assert res.fname(0) == b"an_int"
+ assert res.fname(1) == b"a_num"
+ assert res.fname(2) == b"a_bin_text"
+
+ assert res.fformat(0) == 0
+ assert res.fformat(1) == 0
+ assert res.fformat(2) == 1
+
+ assert res.ftype(0) == 23
+ assert res.ftype(1) == 1700
+ assert res.ftype(2) == 25
+
+ with pytest.raises(psycopg.OperationalError):
+ res.set_attributes(attrs)
diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py
new file mode 100644
index 0000000..0566151
--- /dev/null
+++ b/tests/pq/test_pgconn.py
@@ -0,0 +1,585 @@
+import os
+import sys
+import ctypes
+import logging
+import weakref
+from select import select
+
+import pytest
+
+import psycopg
+from psycopg import pq
+import psycopg.generators
+
+from ..utils import gc_collect
+
+
+def test_connectdb(dsn):
+ conn = pq.PGconn.connect(dsn.encode())
+ assert conn.status == pq.ConnStatus.OK, conn.error_message
+
+
+def test_connectdb_error():
+ conn = pq.PGconn.connect(b"dbname=psycopg_test_not_for_real")
+ assert conn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.parametrize("baddsn", [None, 42])
+def test_connectdb_badtype(baddsn):
+ with pytest.raises(TypeError):
+ pq.PGconn.connect(baddsn)
+
+
+def test_connect_async(dsn):
+ conn = pq.PGconn.connect_start(dsn.encode())
+ conn.nonblocking = 1
+ while True:
+ assert conn.status != pq.ConnStatus.BAD
+ rv = conn.connect_poll()
+ if rv == pq.PollingStatus.OK:
+ break
+ elif rv == pq.PollingStatus.READING:
+ select([conn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [conn.socket], [])
+ else:
+ assert False, rv
+
+ assert conn.status == pq.ConnStatus.OK
+
+ conn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ conn.connect_poll()
+
+
+@pytest.mark.crdb("skip", reason="connects to any db name")
+def test_connect_async_bad(dsn):
+ parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val}
+ parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real"
+ dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items())
+ conn = pq.PGconn.connect_start(dsn)
+ while True:
+ assert conn.status != pq.ConnStatus.BAD, conn.error_message
+ rv = conn.connect_poll()
+ if rv == pq.PollingStatus.FAILED:
+ break
+ elif rv == pq.PollingStatus.READING:
+ select([conn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [conn.socket], [])
+ else:
+ assert False, rv
+
+ assert conn.status == pq.ConnStatus.BAD
+
+
+def test_finish(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.finish()
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.finish()
+ assert pgconn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.slow
+def test_weakref(dsn):
+ conn = pq.PGconn.connect(dsn.encode())
+ w = weakref.ref(conn)
+ conn.finish()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+@pytest.mark.skipif(
+ sys.platform == "win32"
+ and os.environ.get("CI") == "true"
+ and pq.__impl__ != "python",
+ reason="can't figure out how to make ctypes run, don't care",
+)
+def test_pgconn_ptr(pgconn, libpq):
+ assert isinstance(pgconn.pgconn_ptr, int)
+
+ f = libpq.PQserverVersion
+ f.argtypes = [ctypes.c_void_p]
+ f.restype = ctypes.c_int
+ ver = f(pgconn.pgconn_ptr)
+ assert ver == pgconn.server_version
+
+ pgconn.finish()
+ assert pgconn.pgconn_ptr is None
+
+
+def test_info(dsn, pgconn):
+ info = pgconn.info
+ assert len(info) > 20
+ dbname = [d for d in info if d.keyword == b"dbname"][0]
+ assert dbname.envvar == b"PGDATABASE"
+ assert dbname.label == b"Database-Name"
+ assert dbname.dispchar == b""
+ assert dbname.dispsize == 20
+
+ parsed = pq.Conninfo.parse(dsn.encode())
+ # take the name and the user either from params or from env vars
+ name = [
+ o.val or os.environ.get(o.envvar.decode(), "").encode()
+ for o in parsed
+ if o.keyword == b"dbname" and o.envvar
+ ][0]
+ user = [
+ o.val or os.environ.get(o.envvar.decode(), "").encode()
+ for o in parsed
+ if o.keyword == b"user" and o.envvar
+ ][0]
+ assert dbname.val == (name or user)
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.info
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_reset(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.reset()
+ assert pgconn.status == pq.ConnStatus.OK
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset()
+
+ assert pgconn.status == pq.ConnStatus.BAD
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_reset_async(pgconn):
+ assert pgconn.status == pq.ConnStatus.OK
+ pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
+ assert pgconn.status == pq.ConnStatus.BAD
+ pgconn.reset_start()
+ while True:
+ rv = pgconn.reset_poll()
+ if rv == pq.PollingStatus.READING:
+ select([pgconn.socket], [], [])
+ elif rv == pq.PollingStatus.WRITING:
+ select([], [pgconn.socket], [])
+ else:
+ break
+
+ assert rv == pq.PollingStatus.OK
+ assert pgconn.status == pq.ConnStatus.OK
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset_start()
+
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.reset_poll()
+
+
+def test_ping(dsn):
+ rv = pq.PGconn.ping(dsn.encode())
+ assert rv == pq.Ping.OK
+
+ rv = pq.PGconn.ping(b"port=9999")
+ assert rv == pq.Ping.NO_RESPONSE
+
+
+def test_db(pgconn):
+ name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0]
+ assert pgconn.db == name
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.db
+
+
+def test_user(pgconn):
+ user = [o.val for o in pgconn.info if o.keyword == b"user"][0]
+ assert pgconn.user == user
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.user
+
+
+def test_password(pgconn):
+ # not in info
+ assert isinstance(pgconn.password, bytes)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.password
+
+
+def test_host(pgconn):
+ # might be not in info
+ assert isinstance(pgconn.host, bytes)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.host
+
+
+@pytest.mark.libpq(">= 12")
+def test_hostaddr(pgconn):
+ # not in info
+ assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.hostaddr
+
+
+@pytest.mark.libpq("< 12")
+def test_hostaddr_missing(pgconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.hostaddr
+
+
+def test_port(pgconn):
+ port = [o.val for o in pgconn.info if o.keyword == b"port"][0]
+ assert pgconn.port == port
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.port
+
+
+@pytest.mark.libpq("< 14")
+def test_tty(pgconn):
+ tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0]
+ assert pgconn.tty == tty
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.tty
+
+
+@pytest.mark.libpq(">= 14")
+def test_tty_noop(pgconn):
+ assert not any(o.val for o in pgconn.info if o.keyword == b"tty")
+ assert pgconn.tty == b""
+
+
+def test_transaction_status(pgconn):
+ assert pgconn.transaction_status == pq.TransactionStatus.IDLE
+ pgconn.exec_(b"begin")
+ assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
+ pgconn.send_query(b"select 1")
+ assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ psycopg.waiting.wait(psycopg.generators.execute(pgconn), pgconn.socket)
+ assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
+ pgconn.finish()
+ assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN
+
+
+def test_parameter_status(dsn, monkeypatch):
+ monkeypatch.setenv("PGAPPNAME", "psycopg tests")
+ pgconn = pq.PGconn.connect(dsn.encode())
+ assert pgconn.parameter_status(b"application_name") == b"psycopg tests"
+ assert pgconn.parameter_status(b"wat") is None
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.parameter_status(b"application_name")
+
+
+@pytest.mark.crdb_skip("encoding")
+def test_encoding(pgconn):
+ res = pgconn.exec_(b"set client_encoding to latin1")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert pgconn.parameter_status(b"client_encoding") == b"LATIN1"
+
+ res = pgconn.exec_(b"set client_encoding to 'utf-8'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert pgconn.parameter_status(b"client_encoding") == b"UTF8"
+
+ res = pgconn.exec_(b"set client_encoding to wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ assert pgconn.parameter_status(b"client_encoding") == b"UTF8"
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.parameter_status(b"client_encoding")
+
+
+def test_protocol_version(pgconn):
+ assert pgconn.protocol_version == 3
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.protocol_version
+
+
+def test_server_version(pgconn):
+ assert pgconn.server_version >= 90400
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.server_version
+
+
+def test_socket(pgconn):
+ socket = pgconn.socket
+ assert socket > 0
+ pgconn.exec_(f"select pg_terminate_backend({pgconn.backend_pid})".encode())
+ # TODO: on my box it raises OperationalError as it should. Not on Travis,
+ # so let's see if at least an ok value comes out of it.
+ try:
+ assert pgconn.socket == socket
+ except psycopg.OperationalError:
+ pass
+
+
+def test_error_message(pgconn):
+ assert pgconn.error_message == b""
+ res = pgconn.exec_(b"wat")
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ msg = pgconn.error_message
+ assert b"wat" in msg
+ pgconn.finish()
+ assert b"NULL" in pgconn.error_message # TODO: i10n?
+
+
+def test_backend_pid(pgconn):
+ assert isinstance(pgconn.backend_pid, int)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.backend_pid
+
+
+def test_needs_password(pgconn):
+ # assume connection worked so an eventually needed password wasn't missing
+ assert pgconn.needs_password is False
+ pgconn.finish()
+ pgconn.needs_password
+
+
+def test_used_password(pgconn, dsn, monkeypatch):
+ assert isinstance(pgconn.used_password, bool)
+
+ # Assume that if a password was passed then it was needed.
+ # Note that the server may still need a password passed via pgpass
+ # so it may be that has_password is false but still a password was
+ # requested by the server and passed by libpq.
+ info = pq.Conninfo.parse(dsn.encode())
+ has_password = (
+ "PGPASSWORD" in os.environ
+ or [i for i in info if i.keyword == b"password"][0].val is not None
+ )
+ if has_password:
+ assert pgconn.used_password
+
+ pgconn.finish()
+ pgconn.used_password
+
+
+def test_ssl_in_use(pgconn):
+ assert isinstance(pgconn.ssl_in_use, bool)
+
+ # If connecting via socket then ssl is not in use
+ if pgconn.host.startswith(b"/"):
+ assert not pgconn.ssl_in_use
+ else:
+ sslmode = [i.val for i in pgconn.info if i.keyword == b"sslmode"][0]
+ if sslmode not in (b"disable", b"allow", b"prefer"):
+ assert pgconn.ssl_in_use
+
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.ssl_in_use
+
+
+def test_set_single_row_mode(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.set_single_row_mode()
+
+ pgconn.send_query(b"select 1")
+ pgconn.set_single_row_mode()
+
+
+def test_cancel(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.cancel()
+ cancel.cancel()
+ pgconn.finish()
+ cancel.cancel()
+ with pytest.raises(psycopg.OperationalError):
+ pgconn.get_cancel()
+
+
+def test_cancel_free(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.free()
+ with pytest.raises(psycopg.OperationalError):
+ cancel.cancel()
+ cancel.free()
+
+
+@pytest.mark.crdb_skip("notify")
+def test_notify(pgconn):
+ assert pgconn.notifies() is None
+
+ pgconn.exec_(b"listen foo")
+ pgconn.exec_(b"listen bar")
+ pgconn.exec_(b"notify foo, '1'")
+ pgconn.exec_(b"notify bar, '2'")
+ pgconn.exec_(b"notify foo, '3'")
+
+ n = pgconn.notifies()
+ assert n.relname == b"foo"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"1"
+
+ n = pgconn.notifies()
+ assert n.relname == b"bar"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"2"
+
+ n = pgconn.notifies()
+ assert n.relname == b"foo"
+ assert n.be_pid == pgconn.backend_pid
+ assert n.extra == b"3"
+
+ assert pgconn.notifies() is None
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_nohandler(pgconn):
+ pgconn.exec_(b"set client_min_messages to notice")
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice(pgconn):
+ msgs = []
+
+ def callback(res):
+ assert res.status == pq.ExecStatus.NONFATAL_ERROR
+ msgs.append(res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY))
+
+ pgconn.exec_(b"set client_min_messages to notice")
+ pgconn.notice_handler = callback
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert msgs and msgs[0] == b"hello notice"
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_error(pgconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ def callback(res):
+ raise Exception("hello error")
+
+ pgconn.exec_(b"set client_min_messages to notice")
+ pgconn.notice_handler = callback
+ res = pgconn.exec_(
+ b"do $$begin raise notice 'hello notice'; end$$ language plpgsql"
+ )
+
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello error" in rec.message
+
+
+@pytest.mark.libpq("< 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace_pre14(pgconn, tmp_path):
+ tracef = tmp_path / "trace"
+ with tracef.open("w") as f:
+ pgconn.trace(f.fileno())
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.set_trace_flags(0)
+ pgconn.exec_(b"select 1")
+ pgconn.untrace()
+ pgconn.exec_(b"select 2")
+ traces = tracef.read_text()
+ assert "select 1" in traces
+ assert "select 2" not in traces
+
+
+@pytest.mark.libpq(">= 14")
+@pytest.mark.skipif("sys.platform != 'linux'")
+def test_trace(pgconn, tmp_path):
+ tracef = tmp_path / "trace"
+ with tracef.open("w") as f:
+ pgconn.trace(f.fileno())
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ pgconn.exec_(b"select 1::int4 as foo")
+ pgconn.untrace()
+ pgconn.exec_(b"select 2::int4 as foo")
+ traces = [line.split("\t") for line in tracef.read_text().splitlines()]
+ assert traces == [
+ ["F", "26", "Query", ' "select 1::int4 as foo"'],
+ ["B", "28", "RowDescription", ' 1 "foo" NNNN 0 NNNN 4 -1 0'],
+ ["B", "11", "DataRow", " 1 1 '1'"],
+ ["B", "13", "CommandComplete", ' "SELECT 1"'],
+ ["B", "5", "ReadyForQuery", " I"],
+ ]
+
+
+@pytest.mark.skipif("sys.platform == 'linux'")
+def test_trace_nonlinux(pgconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.trace(1)
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password(pgconn):
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5")
+ assert enc == b"md594839d658c28a357126f105b9cb14cfc"
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_scram(pgconn):
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256")
+ assert enc.startswith(b"SCRAM-SHA-256$")
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_badalgo(pgconn):
+ with pytest.raises(psycopg.OperationalError):
+ assert pgconn.encrypt_password(b"psycopg2", b"ashesh", b"wat")
+
+
+@pytest.mark.libpq(">= 10")
+@pytest.mark.crdb_skip("password_encryption")
+def test_encrypt_password_query(pgconn):
+ res = pgconn.exec_(b"set password_encryption to 'md5'")
+ assert res.status == pq.ExecStatus.COMMAND_OK, pgconn.error_message.decode()
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh")
+ assert enc == b"md594839d658c28a357126f105b9cb14cfc"
+
+ res = pgconn.exec_(b"set password_encryption to 'scram-sha-256'")
+ assert res.status == pq.ExecStatus.COMMAND_OK
+ enc = pgconn.encrypt_password(b"psycopg2", b"ashesh")
+ assert enc.startswith(b"SCRAM-SHA-256$")
+
+
+@pytest.mark.libpq(">= 10")
+def test_encrypt_password_closed(pgconn):
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ assert pgconn.encrypt_password(b"psycopg2", b"ashesh")
+
+
+@pytest.mark.libpq("< 10")
+def test_encrypt_password_not_supported(pgconn):
+ # it might even be supported, but not worth the lifetime
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5")
+
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256")
+
+
+def test_str(pgconn, dsn):
+ assert "[IDLE]" in str(pgconn)
+ pgconn.finish()
+ assert "[BAD]" in str(pgconn)
+
+ pgconn2 = pq.PGconn.connect_start(dsn.encode())
+ assert "[" in str(pgconn2)
+ assert "[IDLE]" not in str(pgconn2)
diff --git a/tests/pq/test_pgresult.py b/tests/pq/test_pgresult.py
new file mode 100644
index 0000000..3ad818d
--- /dev/null
+++ b/tests/pq/test_pgresult.py
@@ -0,0 +1,207 @@
+import ctypes
+import pytest
+
+from psycopg import pq
+
+
+@pytest.mark.parametrize(
+ "command, status",
+ [
+ (b"", "EMPTY_QUERY"),
+ (b"select 1", "TUPLES_OK"),
+ (b"set timezone to utc", "COMMAND_OK"),
+ (b"wat", "FATAL_ERROR"),
+ ],
+)
+def test_status(pgconn, command, status):
+ res = pgconn.exec_(command)
+ assert res.status == getattr(pq.ExecStatus, status)
+ assert status in repr(res)
+
+
+def test_clear(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.status == pq.ExecStatus.TUPLES_OK
+ res.clear()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+ res.clear()
+ assert res.status == pq.ExecStatus.FATAL_ERROR
+
+
+def test_pgresult_ptr(pgconn, libpq):
+ res = pgconn.exec_(b"select 1")
+ assert isinstance(res.pgresult_ptr, int)
+
+ f = libpq.PQcmdStatus
+ f.argtypes = [ctypes.c_void_p]
+ f.restype = ctypes.c_char_p
+ assert f(res.pgresult_ptr) == b"SELECT 1"
+
+ res.clear()
+ assert res.pgresult_ptr is None
+
+
+def test_error_message(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.error_message == b""
+ res = pgconn.exec_(b"select wat")
+ assert b"wat" in res.error_message
+ res.clear()
+ assert res.error_message == b""
+
+
+def test_error_field(pgconn):
+ res = pgconn.exec_(b"select wat")
+ # https://github.com/cockroachdb/cockroach/issues/81794
+ assert (
+ res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+ or res.error_field(pq.DiagnosticField.SEVERITY)
+ ) == b"ERROR"
+ assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703"
+ assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+ res.clear()
+ assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None
+
+
+@pytest.mark.parametrize("n", range(4))
+def test_ntuples(pgconn, n):
+ res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")])
+ assert res.ntuples == n
+ res.clear()
+ assert res.ntuples == 0
+
+
+def test_nfields(pgconn):
+ res = pgconn.exec_(b"select wat")
+ assert res.nfields == 0
+ res = pgconn.exec_(b"select 1, 2, 3")
+ assert res.nfields == 3
+ res.clear()
+ assert res.nfields == 0
+
+
+def test_fname(pgconn):
+ res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"')
+ assert res.fname(0) == b"foo"
+ assert res.fname(1) == b"BAR"
+ assert res.fname(2) is None
+ assert res.fname(-1) is None
+ res.clear()
+ assert res.fname(0) is None
+
+
+@pytest.mark.crdb("skip", reason="ftable")
+def test_ftable_and_col(pgconn):
+ res = pgconn.exec_(
+ b"""
+ drop table if exists t1, t2;
+ create table t1 as select 1 as f1;
+ create table t2 as select 2 as f2, 3 as f3;
+ """
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.exec_(
+ b"select f1, f3, 't1'::regclass::oid, 't2'::regclass::oid from t1, t2"
+ )
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+
+ assert res.ftable(0) == int(res.get_value(0, 2).decode("ascii"))
+ assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii"))
+ assert res.ftablecol(0) == 1
+ assert res.ftablecol(1) == 2
+ res.clear()
+ assert res.ftable(0) == 0
+ assert res.ftablecol(0) == 0
+
+
+@pytest.mark.parametrize("fmt", (0, 1))
+def test_fformat(pgconn, fmt):
+ res = pgconn.exec_params(b"select 1", [], result_format=fmt)
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fformat(0) == fmt
+ assert res.binary_tuples == fmt
+ res.clear()
+ assert res.fformat(0) == 0
+ assert res.binary_tuples == 0
+
+
+def test_ftype(pgconn):
+ res = pgconn.exec_(b"select 1::int4, 1::numeric, 1::text")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.ftype(0) == 23
+ assert res.ftype(1) == 1700
+ assert res.ftype(2) == 25
+ res.clear()
+ assert res.ftype(0) == 0
+
+
+def test_fmod(pgconn):
+ res = pgconn.exec_(b"select 1::int, 1::numeric(10), 1::numeric(10,2)")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fmod(0) == -1
+ assert res.fmod(1) == 0xA0004
+ assert res.fmod(2) == 0xA0006
+ res.clear()
+ assert res.fmod(0) == 0
+
+
+def test_fsize(pgconn):
+ res = pgconn.exec_(b"select 1::int4, 1::bigint, 1::text")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.fsize(0) == 4
+ assert res.fsize(1) == 8
+ assert res.fsize(2) == -1
+ res.clear()
+ assert res.fsize(0) == 0
+
+
+def test_get_value(pgconn):
+ res = pgconn.exec_(b"select 'a', '', NULL")
+ assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message
+ assert res.get_value(0, 0) == b"a"
+ assert res.get_value(0, 1) == b""
+ assert res.get_value(0, 2) is None
+ res.clear()
+ assert res.get_value(0, 0) is None
+
+
+def test_nparams_types(pgconn):
+ res = pgconn.prepare(b"", b"select $1::int4, $2::text")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ res = pgconn.describe_prepared(b"")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+ assert res.nparams == 2
+ assert res.param_type(0) == 23
+ assert res.param_type(1) == 25
+
+ res.clear()
+ assert res.nparams == 0
+ assert res.param_type(0) == 0
+
+
+def test_command_status(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.command_status == b"SELECT 1"
+ res = pgconn.exec_(b"set timezone to utc")
+ assert res.command_status == b"SET"
+ res.clear()
+ assert res.command_status is None
+
+
+def test_command_tuples(pgconn):
+ res = pgconn.exec_(b"set timezone to utf8")
+ assert res.command_tuples is None
+ res = pgconn.exec_(b"select * from generate_series(1, 10)")
+ assert res.command_tuples == 10
+ res.clear()
+ assert res.command_tuples is None
+
+
+def test_oid_value(pgconn):
+ res = pgconn.exec_(b"select 1")
+ assert res.oid_value == 0
+ res.clear()
+ assert res.oid_value == 0
diff --git a/tests/pq/test_pipeline.py b/tests/pq/test_pipeline.py
new file mode 100644
index 0000000..00cd54a
--- /dev/null
+++ b/tests/pq/test_pipeline.py
@@ -0,0 +1,161 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+
+
+@pytest.mark.libpq("< 14")
+def test_old_libpq(pgconn):
+ assert pgconn.pipeline_status == 0
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.enter_pipeline_mode()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.exit_pipeline_mode()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.pipeline_sync()
+ with pytest.raises(psycopg.NotSupportedError):
+ pgconn.send_flush_request()
+
+
+@pytest.mark.libpq(">= 14")
+def test_work_in_progress(pgconn):
+ assert not pgconn.nonblocking
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"])
+ with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"):
+ pgconn.exit_pipeline_mode()
+
+
+@pytest.mark.libpq(">= 14")
+def test_multi_pipelines(pgconn):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"], param_types=[25])
+ pgconn.pipeline_sync()
+ pgconn.send_query_params(b"select $1", [b"2"], param_types=[25])
+ pgconn.pipeline_sync()
+
+ # result from first query
+ result1 = pgconn.get_result()
+ assert result1 is not None
+ assert result1.status == pq.ExecStatus.TUPLES_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # first sync result
+ sync_result = pgconn.get_result()
+ assert sync_result is not None
+ assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # result from second query
+ result2 = pgconn.get_result()
+ assert result2 is not None
+ assert result2.status == pq.ExecStatus.TUPLES_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # second sync result
+ sync_result = pgconn.get_result()
+ assert sync_result is not None
+ assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # pipeline still ON
+ assert pgconn.pipeline_status == pq.PipelineStatus.ON
+
+ pgconn.exit_pipeline_mode()
+
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+
+ assert result1.get_value(0, 0) == b"1"
+ assert result2.get_value(0, 0) == b"2"
+
+
+@pytest.mark.libpq(">= 14")
+def test_flush_request(pgconn):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"select $1", [b"1"], param_types=[25])
+ pgconn.send_flush_request()
+ r = pgconn.get_result()
+ assert r.status == pq.ExecStatus.TUPLES_OK
+ assert r.get_value(0, 0) == b"1"
+ pgconn.exit_pipeline_mode()
+
+
+@pytest.fixture
+def table(pgconn):
+ tablename = "pipeline"
+ pgconn.exec_(f"create table {tablename} (s text)".encode("ascii"))
+ yield tablename
+ pgconn.exec_(f"drop table if exists {tablename}".encode("ascii"))
+
+
+@pytest.mark.libpq(">= 14")
+def test_pipeline_abort(pgconn, table):
+ assert pgconn.pipeline_status == pq.PipelineStatus.OFF
+ pgconn.enter_pipeline_mode()
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"1"])
+ pgconn.send_query_params(b"select no_such_function($1)", [b"1"])
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"2"])
+ pgconn.pipeline_sync()
+ pgconn.send_query_params(b"insert into pipeline values ($1)", [b"3"])
+ pgconn.pipeline_sync()
+
+ # result from first INSERT
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.COMMAND_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # error result from second query (SELECT)
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.FATAL_ERROR
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # pipeline should be aborted, due to previous error
+ assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED
+
+ # result from second INSERT, aborted due to previous error
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_ABORTED
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # pipeline is still aborted
+ assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED
+
+ # sync result
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # aborted flag is clear, pipeline is on again
+ assert pgconn.pipeline_status == pq.PipelineStatus.ON
+
+ # result from the third INSERT
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.COMMAND_OK
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ # second sync result
+ r = pgconn.get_result()
+ assert r is not None
+ assert r.status == pq.ExecStatus.PIPELINE_SYNC
+
+ # NULL signals end of result
+ assert pgconn.get_result() is None
+
+ pgconn.exit_pipeline_mode()
diff --git a/tests/pq/test_pq.py b/tests/pq/test_pq.py
new file mode 100644
index 0000000..076c3b6
--- /dev/null
+++ b/tests/pq/test_pq.py
@@ -0,0 +1,57 @@
+import os
+
+import pytest
+
+import psycopg
+from psycopg import pq
+
+from ..utils import check_libpq_version
+
+
+def test_version():
+ rv = pq.version()
+ assert rv > 90500
+ assert rv < 200000 # you are good for a while
+
+
+def test_build_version():
+ assert pq.__build_version__ and pq.__build_version__ >= 70400
+
+
+@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_BUILD')")
+def test_want_built_version():
+ want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_BUILD"]
+ got = pq.__build_version__
+ assert not check_libpq_version(got, want)
+
+
+@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_IMPORT')")
+def test_want_import_version():
+ want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_IMPORT"]
+ got = pq.version()
+ assert not check_libpq_version(got, want)
+
+
+# Note: These tests are here because test_pipeline.py tests are all skipped
+# when pipeline mode is not supported.
+
+
+@pytest.mark.libpq(">= 14")
+def test_pipeline_supported(conn):
+ assert psycopg.Pipeline.is_supported()
+ assert psycopg.AsyncPipeline.is_supported()
+
+ with conn.pipeline():
+ pass
+
+
+@pytest.mark.libpq("< 14")
+def test_pipeline_not_supported(conn):
+ assert not psycopg.Pipeline.is_supported()
+ assert not psycopg.AsyncPipeline.is_supported()
+
+ with pytest.raises(psycopg.NotSupportedError) as exc:
+ with conn.pipeline():
+ pass
+
+ assert "too old" in str(exc.value)
diff --git a/tests/scripts/bench-411.py b/tests/scripts/bench-411.py
new file mode 100644
index 0000000..82ea451
--- /dev/null
+++ b/tests/scripts/bench-411.py
@@ -0,0 +1,300 @@
+import os
+import sys
+import time
+import random
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Generator
+from argparse import ArgumentParser, Namespace
+from contextlib import contextmanager
+
+logger = logging.getLogger()
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(message)s",
+)
+
+
+class Driver(str, Enum):
+ psycopg2 = "psycopg2"
+ psycopg = "psycopg"
+ psycopg_async = "psycopg_async"
+ asyncpg = "asyncpg"
+
+
+ids: List[int] = []
+data: List[Dict[str, Any]] = []
+
+
+def main() -> None:
+
+ args = parse_cmdline()
+
+ ids[:] = range(args.ntests)
+ data[:] = [
+ dict(
+ id=i,
+ name="c%d" % i,
+ description="c%d" % i,
+ q=i * 10,
+ p=i * 20,
+ x=i * 30,
+ y=i * 40,
+ )
+ for i in ids
+ ]
+
+ # Must be done just on end
+ drop_at_the_end = args.drop
+ args.drop = False
+
+ for i, name in enumerate(args.drivers):
+ if i == len(args.drivers) - 1:
+ args.drop = drop_at_the_end
+
+ if name == Driver.psycopg2:
+ import psycopg2 # type: ignore
+
+ run_psycopg2(psycopg2, args)
+
+ elif name == Driver.psycopg:
+ import psycopg
+
+ run_psycopg(psycopg, args)
+
+ elif name == Driver.psycopg_async:
+ import psycopg
+
+ if sys.platform == "win32":
+ if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ asyncio.set_event_loop_policy(
+ asyncio.WindowsSelectorEventLoopPolicy()
+ )
+
+ asyncio.run(run_psycopg_async(psycopg, args))
+
+ elif name == Driver.asyncpg:
+ import asyncpg # type: ignore
+
+ asyncio.run(run_asyncpg(asyncpg, args))
+
+ else:
+ raise AssertionError(f"unknown driver: {name!r}")
+
+ # Must be done just on start
+ args.create = False
+
+
+table = """
+CREATE TABLE customer (
+ id SERIAL NOT NULL,
+ name VARCHAR(255),
+ description VARCHAR(255),
+ q INTEGER,
+ p INTEGER,
+ x INTEGER,
+ y INTEGER,
+ z INTEGER,
+ PRIMARY KEY (id)
+)
+"""
+drop = "DROP TABLE IF EXISTS customer"
+
+insert = """
+INSERT INTO customer (id, name, description, q, p, x, y) VALUES
+(%(id)s, %(name)s, %(description)s, %(q)s, %(p)s, %(x)s, %(y)s)
+"""
+
+select = """
+SELECT customer.id, customer.name, customer.description, customer.q,
+ customer.p, customer.x, customer.y, customer.z
+FROM customer
+WHERE customer.id = %(id)s
+"""
+
+
+@contextmanager
+def time_log(message: str) -> Generator[None, None, None]:
+ start = time.monotonic()
+ yield
+ end = time.monotonic()
+ logger.info(f"Run {message} in {end-start} s")
+
+
+def run_psycopg2(psycopg2: Any, args: Namespace) -> None:
+ logger.info("Running psycopg2")
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ with psycopg2.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ cursor.execute(table)
+ cursor.executemany(insert, data)
+ conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ with psycopg2.connect(args.dsn) as conn:
+ with time_log("psycopg2"):
+ for id_ in to_query:
+ with conn.cursor() as cursor:
+ cursor.execute(select, {"id": id_})
+ cursor.fetchall()
+ # conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ with psycopg2.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ conn.commit()
+
+
+def run_psycopg(psycopg: Any, args: Namespace) -> None:
+ logger.info("Running psycopg sync")
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ with psycopg.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ cursor.execute(table)
+ cursor.executemany(insert, data)
+ conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ with psycopg.connect(args.dsn) as conn:
+ with time_log("psycopg"):
+ for id_ in to_query:
+ with conn.cursor() as cursor:
+ cursor.execute(select, {"id": id_})
+ cursor.fetchall()
+ # conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ with psycopg.connect(args.dsn) as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(drop)
+ conn.commit()
+
+
+async def run_psycopg_async(psycopg: Any, args: Namespace) -> None:
+ logger.info("Running psycopg async")
+
+ conn: Any
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ async with conn.cursor() as cursor:
+ await cursor.execute(drop)
+ await cursor.execute(table)
+ await cursor.executemany(insert, data)
+ await conn.commit()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ with time_log("psycopg_async"):
+ for id_ in to_query:
+ cursor = await conn.execute(select, {"id": id_})
+ await cursor.fetchall()
+ await cursor.close()
+ # await conn.rollback()
+
+ if args.drop:
+ logger.info("dropping test records")
+ async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+ async with conn.cursor() as cursor:
+ await cursor.execute(drop)
+ await conn.commit()
+
+
+async def run_asyncpg(asyncpg: Any, args: Namespace) -> None:
+ logger.info("Running asyncpg")
+
+ places = dict(id="$1", name="$2", description="$3", q="$4", p="$5", x="$6", y="$7")
+ a_insert = insert % places
+ a_select = select % {"id": "$1"}
+
+ conn: Any
+
+ if args.create:
+ logger.info(f"inserting {args.ntests} test records")
+ conn = await asyncpg.connect(args.dsn)
+ async with conn.transaction():
+ await conn.execute(drop)
+ await conn.execute(table)
+ await conn.executemany(a_insert, [tuple(d.values()) for d in data])
+ await conn.close()
+
+ logger.info(f"running {args.ntests} queries")
+ to_query = random.choices(ids, k=args.ntests)
+ conn = await asyncpg.connect(args.dsn)
+ with time_log("asyncpg"):
+ for id_ in to_query:
+ tr = conn.transaction()
+ await tr.start()
+ await conn.fetch(a_select, id_)
+ # await tr.rollback()
+ await conn.close()
+
+ if args.drop:
+ logger.info("dropping test records")
+ conn = await asyncpg.connect(args.dsn)
+ async with conn.transaction():
+ await conn.execute(drop)
+ await conn.close()
+
+
+def parse_cmdline() -> Namespace:
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "drivers",
+ nargs="+",
+ metavar="DRIVER",
+ type=Driver,
+ help=f"the drivers to test [choices: {', '.join(d.value for d in Driver)}]",
+ )
+
+ parser.add_argument(
+ "--ntests",
+ type=int,
+ default=10_000,
+ help="number of tests to perform [default: %(default)s]",
+ )
+
+ parser.add_argument(
+ "--dsn",
+ default=os.environ.get("PSYCOPG_TEST_DSN", ""),
+ help="database connection string"
+ " [default: %(default)r (from PSYCOPG_TEST_DSN env var)]",
+ )
+
+ parser.add_argument(
+ "--no-create",
+ dest="create",
+ action="store_false",
+ default="True",
+ help="skip data creation before tests (it must exist already)",
+ )
+
+ parser.add_argument(
+ "--no-drop",
+ dest="drop",
+ action="store_false",
+ default="True",
+ help="skip data drop after tests",
+ )
+
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/dectest.py b/tests/scripts/dectest.py
new file mode 100644
index 0000000..a49f116
--- /dev/null
+++ b/tests/scripts/dectest.py
@@ -0,0 +1,51 @@
+"""
+A quick and rough performance comparison of text vs. binary Decimal adaptation
+"""
+from random import randrange
+from decimal import Decimal
+import psycopg
+from psycopg import sql
+
+ncols = 10
+nrows = 500000
+format = psycopg.pq.Format.BINARY
+test = "copy"
+
+
+def main() -> None:
+ cnn = psycopg.connect()
+
+ cnn.execute(
+ sql.SQL("create table testdec ({})").format(
+ sql.SQL(", ").join(
+ [
+ sql.SQL("{} numeric(10,2)").format(sql.Identifier(f"t{i}"))
+ for i in range(ncols)
+ ]
+ )
+ )
+ )
+ cur = cnn.cursor()
+
+ if test == "copy":
+ with cur.copy(f"copy testdec from stdin (format {format.name})") as copy:
+ for j in range(nrows):
+ copy.write_row(
+ [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
+ )
+
+ elif test == "insert":
+ ph = ["%t", "%b"][format]
+ cur.executemany(
+ "insert into testdec values (%s)" % ", ".join([ph] * ncols),
+ (
+ [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
+ for j in range(nrows)
+ ),
+ )
+ else:
+ raise Exception(f"bad test: {test}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py
new file mode 100644
index 0000000..ec95229
--- /dev/null
+++ b/tests/scripts/pipeline-demo.py
@@ -0,0 +1,340 @@
+"""Pipeline mode demo
+
+This reproduces libpq_pipeline::pipelined_insert PostgreSQL test at
+src/test/modules/libpq_pipeline/libpq_pipeline.c::test_pipelined_insert().
+
+We do not fetch results explicitly (using cursor.fetch*()), this is
+handled by execute() calls when pgconn socket is read-ready, which
+happens when the output buffer is full.
+"""
+import argparse
+import asyncio
+import logging
+from contextlib import contextmanager
+from functools import partial
+from typing import Any, Iterator, Optional, Sequence, Tuple
+
+from psycopg import AsyncConnection, Connection
+from psycopg import pq, waiting
+from psycopg import errors as e
+from psycopg.abc import PipelineCommand
+from psycopg.generators import pipeline_communicate
+from psycopg.pq import Format, DiagnosticField
+from psycopg._compat import Deque
+
+psycopg_logger = logging.getLogger("psycopg")
+pipeline_logger = logging.getLogger("pipeline")
+args: argparse.Namespace
+
+
+class LoggingPGconn:
+ """Wrapper for PGconn that logs fetched results."""
+
+ def __init__(self, pgconn: pq.abc.PGconn, logger: logging.Logger):
+ self._pgconn = pgconn
+ self._logger = logger
+
+ def log_notice(result: pq.abc.PGresult) -> None:
+ def get_field(field: DiagnosticField) -> Optional[str]:
+ value = result.error_field(field)
+ return value.decode("utf-8", "replace") if value else None
+
+ logger.info(
+ "notice %s %s",
+ get_field(DiagnosticField.SEVERITY),
+ get_field(DiagnosticField.MESSAGE_PRIMARY),
+ )
+
+ pgconn.notice_handler = log_notice
+
+ if args.trace:
+ self._trace_file = open(args.trace, "w")
+ pgconn.trace(self._trace_file.fileno())
+
+ def __del__(self) -> None:
+ if hasattr(self, "_trace_file"):
+ self._pgconn.untrace()
+ self._trace_file.close()
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._pgconn, name)
+
+ def send_query(self, command: bytes) -> None:
+ self._logger.warning("PQsendQuery broken in libpq 14.5")
+ self._pgconn.send_query(command)
+ self._logger.info("sent %s", command.decode())
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ self._pgconn.send_query_params(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._logger.info("sent %s", command.decode())
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ self._pgconn.send_query_prepared(
+ name, param_values, param_formats, result_format
+ )
+ self._logger.info("sent prepared '%s' with %s", name.decode(), param_values)
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ self._pgconn.send_prepare(name, command, param_types)
+ self._logger.info("prepare %s as '%s'", command.decode(), name.decode())
+
+ def get_result(self) -> Optional[pq.abc.PGresult]:
+ r = self._pgconn.get_result()
+ if r is not None:
+ self._logger.info("got %s result", pq.ExecStatus(r.status).name)
+ return r
+
+
+@contextmanager
+def prepare_pipeline_demo_pq(
+ pgconn: LoggingPGconn, rows_to_send: int, logger: logging.Logger
+) -> Iterator[Tuple[Deque[PipelineCommand], Deque[str]]]:
+ """Set up pipeline demo with initial queries and yield commands and
+ results queue for pipeline_communicate().
+ """
+ logger.debug("enter pipeline")
+ pgconn.enter_pipeline_mode()
+
+ setup_queries = [
+ ("begin", "BEGIN TRANSACTION"),
+ ("drop table", "DROP TABLE IF EXISTS pq_pipeline_demo"),
+ (
+ "create table",
+ (
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ ),
+ ),
+ (
+ "prepare",
+ ("INSERT INTO pq_pipeline_demo(itemno, int8filler)" " VALUES ($1, $2)"),
+ ),
+ ]
+
+ commands = Deque[PipelineCommand]()
+ results_queue = Deque[str]()
+
+ for qname, query in setup_queries:
+ if qname == "prepare":
+ pgconn.send_prepare(qname.encode(), query.encode())
+ else:
+ pgconn.send_query_params(query.encode(), None)
+ results_queue.append(qname)
+
+ committed = False
+ synced = False
+
+ while True:
+ if rows_to_send:
+ params = [f"{rows_to_send}".encode(), f"{1 << 62}".encode()]
+ commands.append(partial(pgconn.send_query_prepared, b"prepare", params))
+ results_queue.append(f"row {rows_to_send}")
+ rows_to_send -= 1
+
+ elif not committed:
+ committed = True
+ commands.append(partial(pgconn.send_query_params, b"COMMIT", None))
+ results_queue.append("commit")
+
+ elif not synced:
+
+ def sync() -> None:
+ pgconn.pipeline_sync()
+ logger.info("pipeline sync sent")
+
+ synced = True
+ commands.append(sync)
+ results_queue.append("sync")
+
+ else:
+ break
+
+ try:
+ yield commands, results_queue
+ finally:
+ logger.debug("exit pipeline")
+ pgconn.exit_pipeline_mode()
+
+
+def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None:
+ pgconn = LoggingPGconn(Connection.connect().pgconn, logger)
+ with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as (
+ commands,
+ results_queue,
+ ):
+ while results_queue:
+ fetched = waiting.wait(
+ pipeline_communicate(
+ pgconn, # type: ignore[arg-type]
+ commands,
+ ),
+ pgconn.socket,
+ )
+ assert not commands, commands
+ for results in fetched:
+ results_queue.popleft()
+ for r in results:
+ if r.status in (
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ):
+ raise e.error_from_result(r)
+
+
+async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None:
+ pgconn = LoggingPGconn((await AsyncConnection.connect()).pgconn, logger)
+
+ with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as (
+ commands,
+ results_queue,
+ ):
+ while results_queue:
+ fetched = await waiting.wait_async(
+ pipeline_communicate(
+ pgconn, # type: ignore[arg-type]
+ commands,
+ ),
+ pgconn.socket,
+ )
+ assert not commands, commands
+ for results in fetched:
+ results_queue.popleft()
+ for r in results:
+ if r.status in (
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ):
+ raise e.error_from_result(r)
+
+
+def pipeline_demo(rows_to_send: int, many: bool, logger: logging.Logger) -> None:
+ """Pipeline demo using sync API."""
+ conn = Connection.connect()
+ conn.autocommit = True
+ conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment]
+ with conn.pipeline():
+ with conn.transaction():
+ conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ conn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)"
+ params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1))
+ if many:
+ cur = conn.cursor()
+ cur.executemany(query, list(params))
+ else:
+ for p in params:
+ conn.execute(query, p)
+
+
+async def pipeline_demo_async(
+ rows_to_send: int, many: bool, logger: logging.Logger
+) -> None:
+ """Pipeline demo using async API."""
+ aconn = await AsyncConnection.connect()
+ await aconn.set_autocommit(True)
+ aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment]
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo")
+ await aconn.execute(
+ "CREATE UNLOGGED TABLE pq_pipeline_demo("
+ " id serial primary key,"
+ " itemno integer,"
+ " int8filler int8"
+ ")"
+ )
+ query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)"
+ params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1))
+ if many:
+ cur = aconn.cursor()
+ await cur.executemany(query, list(params))
+ else:
+ for p in params:
+ await aconn.execute(query, p)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-n",
+ dest="nrows",
+ metavar="ROWS",
+ default=10_000,
+ type=int,
+ help="number of rows to insert",
+ )
+ parser.add_argument(
+ "--pq", action="store_true", help="use low-level psycopg.pq API"
+ )
+ parser.add_argument(
+ "--async", dest="async_", action="store_true", help="use async API"
+ )
+ parser.add_argument(
+ "--many",
+ action="store_true",
+ help="use executemany() (not applicable for --pq)",
+ )
+ parser.add_argument("--trace", help="write trace info into TRACE file")
+ parser.add_argument("-l", "--log", help="log file (stderr by default)")
+
+ global args
+ args = parser.parse_args()
+
+ psycopg_logger.setLevel(logging.DEBUG)
+ pipeline_logger.setLevel(logging.DEBUG)
+ if args.log:
+ psycopg_logger.addHandler(logging.FileHandler(args.log))
+ pipeline_logger.addHandler(logging.FileHandler(args.log))
+ else:
+ psycopg_logger.addHandler(logging.StreamHandler())
+ pipeline_logger.addHandler(logging.StreamHandler())
+
+ if args.pq:
+ if args.many:
+ parser.error("--many cannot be used with --pq")
+ if args.async_:
+ asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger))
+ else:
+ pipeline_demo_pq(args.nrows, pipeline_logger)
+ else:
+ if pq.__impl__ != "python":
+ parser.error(
+ "only supported for Python implementation (set PSYCOPG_IMPL=python)"
+ )
+ if args.async_:
+ asyncio.run(pipeline_demo_async(args.nrows, args.many, pipeline_logger))
+ else:
+ pipeline_demo(args.nrows, args.many, pipeline_logger)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/scripts/spiketest.py b/tests/scripts/spiketest.py
new file mode 100644
index 0000000..2c9cc16
--- /dev/null
+++ b/tests/scripts/spiketest.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+"""
+Run a connection pool spike test.
+
+The test is inspired to the `spike analysis`__ illustrated by HikariCP
+
+.. __: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/
+ Welcome-To-The-Jungle.md
+
+"""
+# mypy: allow-untyped-defs
+# mypy: allow-untyped-calls
+
+import time
+import threading
+
+import psycopg
+import psycopg_pool
+from psycopg.rows import Row
+
+import logging
+
+
+def main() -> None:
+ opt = parse_cmdline()
+ if opt.loglevel:
+ loglevel = getattr(logging, opt.loglevel.upper())
+ logging.basicConfig(
+ level=loglevel, format="%(asctime)s %(levelname)s %(message)s"
+ )
+
+ logging.getLogger("psycopg2.pool").setLevel(loglevel)
+
+ with psycopg_pool.ConnectionPool(
+ opt.dsn,
+ min_size=opt.min_size,
+ max_size=opt.max_size,
+ connection_class=DelayedConnection,
+ kwargs={"conn_delay": 0.150},
+ ) as pool:
+ pool.wait()
+ measurer = Measurer(pool)
+
+ # Create and start all the thread: they will get stuck on the event
+ ev = threading.Event()
+ threads = [
+ threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True)
+ for i in range(opt.num_clients)
+ ]
+ for t in threads:
+ t.start()
+ time.sleep(0.2)
+
+ # Release the threads!
+ measurer.start(0.00025)
+ t0 = time.time()
+ ev.set()
+
+ # Wait for the threads to finish
+ for t in threads:
+ t.join()
+ t1 = time.time()
+ measurer.stop()
+
+ print(f"time: {(t1 - t0) * 1000} msec")
+ print("active,idle,total,waiting")
+ recs = [
+ f'{m["pool_size"] - m["pool_available"]}'
+ f',{m["pool_available"]}'
+ f',{m["pool_size"]}'
+ f',{m["requests_waiting"]}'
+ for m in measurer.measures
+ ]
+ print("\n".join(recs))
+
+
+def worker(p, t, ev):
+ ev.wait()
+ with p.connection():
+ time.sleep(t)
+
+
+class Measurer:
+ def __init__(self, pool):
+ self.pool = pool
+ self.worker = None
+ self.stopped = False
+ self.measures = []
+
+ def start(self, interval):
+ self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True)
+ self.worker.start()
+
+ def stop(self):
+ self.stopped = True
+ if self.worker:
+ self.worker.join()
+ self.worker = None
+
+ def _run(self, interval):
+ while not self.stopped:
+ self.measures.append(self.pool.get_stats())
+ time.sleep(interval)
+
+
+class DelayedConnection(psycopg.Connection[Row]):
+ """A connection adding a delay to the connection time."""
+
+ @classmethod
+ def connect(cls, conninfo, conn_delay=0, **kwargs):
+ t0 = time.time()
+ conn = super().connect(conninfo, **kwargs)
+ t1 = time.time()
+ wait = max(0.0, conn_delay - (t1 - t0))
+ if wait:
+ time.sleep(wait)
+ return conn
+
+
+def parse_cmdline():
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument("--dsn", default="", help="connection string to the database")
+ parser.add_argument(
+ "--min_size",
+ default=5,
+ type=int,
+ help="minimum number of connections in the pool",
+ )
+ parser.add_argument(
+ "--max_size",
+ default=20,
+ type=int,
+ help="maximum number of connections in the pool",
+ )
+ parser.add_argument(
+ "--num-clients",
+ default=50,
+ type=int,
+ help="number of threads making a request",
+ )
+ parser.add_argument(
+ "--loglevel",
+ default=None,
+ choices=("DEBUG", "INFO", "WARNING", "ERROR"),
+ help="level to log at [default: no log]",
+ )
+
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
new file mode 100644
index 0000000..2190a84
--- /dev/null
+++ b/tests/test_adapt.py
@@ -0,0 +1,530 @@
+import datetime as dt
+from types import ModuleType
+from typing import Any, List
+
+import pytest
+
+import psycopg
+from psycopg import pq, sql, postgres
+from psycopg import errors as e
+from psycopg.adapt import Transformer, PyFormat, Dumper, Loader
+from psycopg._cmodule import _psycopg
+from psycopg.postgres import types as builtins, TEXT_OID
+from psycopg.types.array import ListDumper, ListBinaryDumper
+
+
+@pytest.mark.parametrize(
+ "data, format, result, type",
+ [
+ (1, PyFormat.TEXT, b"1", "int2"),
+ ("hello", PyFormat.TEXT, b"hello", "text"),
+ ("hello", PyFormat.BINARY, b"hello", "text"),
+ ],
+)
+def test_dump(data, format, result, type):
+ t = Transformer()
+ dumper = t.get_dumper(data, format)
+ assert dumper.dump(data) == result
+ if type == "text" and format != PyFormat.BINARY:
+ assert dumper.oid == 0
+ else:
+ assert dumper.oid == builtins[type].oid
+
+
+@pytest.mark.parametrize(
+ "data, result",
+ [
+ (1, b"1"),
+ ("hello", b"'hello'"),
+ ("he'llo", b"'he''llo'"),
+ (True, b"true"),
+ (None, b"NULL"),
+ ],
+)
+def test_quote(data, result):
+ t = Transformer()
+ dumper = t.get_dumper(data, PyFormat.TEXT)
+ assert dumper.quote(data) == result
+
+
+def test_register_dumper_by_class(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(MyStr, dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+def test_register_dumper_by_class_name(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn):
+ psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_dump_connection_ctx(conn):
+ conn.adapters.register_dumper(MyStr, make_bin_dumper("b"))
+ conn.adapters.register_dumper(MyStr, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_cursor_ctx(conn):
+ conn.adapters.register_dumper(str, make_bin_dumper("b"))
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, make_bin_dumper("bc"))
+ cur.adapters.register_dumper(str, make_dumper("tc"))
+
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_subclass(conn):
+ class MyString(str):
+ pass
+
+ cur = conn.cursor()
+ cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")])
+ assert cur.fetchone() == ("hello", "world")
+
+
+def test_subclass_dumper(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import StrDumper
+
+ class MyStrDumper(StrDumper):
+ def dump(self, obj):
+ return (obj * 2).encode()
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello"
+
+
+def test_dumper_protocol(conn):
+
+ # This class doesn't inherit from adapt.Dumper but passes a mypy check
+ from .adapters_example import MyStrDumper
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select %s", [["hi", "ha"]])
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+ assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+
+
+def test_loader_protocol(conn):
+
+ # This class doesn't inherit from adapt.Loader but passes a mypy check
+ from .adapters_example import MyTextLoader
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ cur = conn.execute("select 'hello'::text")
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select '{hi,ha}'::text[]")
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+
+
+def test_subclass_loader(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import TextLoader
+
+ class MyTextLoader(TextLoader):
+ def load(self, data):
+ return (bytes(data) * 2).decode()
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello"
+
+
+@pytest.mark.parametrize(
+ "data, format, type, result",
+ [
+ (b"1", pq.Format.TEXT, "int4", 1),
+ (b"hello", pq.Format.TEXT, "text", "hello"),
+ (b"hello", pq.Format.BINARY, "text", "hello"),
+ ],
+)
+def test_cast(data, format, type, result):
+ t = Transformer()
+ rv = t.get_loader(builtins[type].oid, format).load(data)
+ assert rv == result
+
+
+def test_register_loader_by_oid(conn):
+ assert TEXT_OID == 25
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader(TEXT_OID, loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+def test_register_loader_by_type_name(conn):
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader("text", loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_load_global_ctx(conn_cls, dsn, global_adapters):
+ psycopg.adapters.register_loader("text", make_loader("gt"))
+ psycopg.adapters.register_loader("text", make_bin_loader("gb"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
+
+
+def test_load_connection_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ r = conn.cursor(binary=False).execute("select 'hello'::text").fetchone()
+ assert r == ("hellot",)
+ r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone()
+ assert r == ("hellob",)
+
+
+def test_load_cursor_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", make_loader("tc"))
+ cur.adapters.register_loader("text", make_bin_loader("bc"))
+
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellotc",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellot",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellob",)
+
+
+def test_cow_dumpers(conn):
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_dumper(str, make_dumper("c2"))
+
+ r = cur1.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+ conn.adapters.register_dumper(str, make_dumper("t1"))
+ r = cur1.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+
+def test_cow_loaders(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_loader("text", make_loader("c2"))
+
+ assert cur1.execute("select 'hello'::text").fetchone() == ("hellot",)
+ assert cur2.execute("select 'hello'::text").fetchone() == ("helloc2",)
+
+ conn.adapters.register_loader("text", make_loader("t1"))
+ assert cur1.execute("select 'hello2'::text").fetchone() == ("hello2t",)
+ assert cur2.execute("select 'hello2'::text").fetchone() == ("hello2c2",)
+
+
+@pytest.mark.parametrize(
+ "sql, obj",
+ [("'{hello}'::text[]", ["helloc"]), ("row('hello'::text)", ("helloc",))],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
+ cur = conn.cursor(binary=fmt_out == pq.Format.BINARY)
+ if fmt_out == pq.Format.TEXT:
+ cur.adapters.register_loader("text", make_loader("c"))
+ else:
+ cur.adapters.register_loader("text", make_bin_loader("c"))
+
+ cur.execute(f"select {sql}")
+ res = cur.fetchone()[0]
+ assert res == obj
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_list_dumper(conn, fmt_out):
+ t = Transformer(conn)
+ fmt_in = PyFormat.from_pq(fmt_out)
+ dint = t.get_dumper([0], fmt_in)
+ assert isinstance(dint, (ListDumper, ListBinaryDumper))
+ assert dint.oid == builtins["int2"].array_oid
+ assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
+
+ dstr = t.get_dumper([""], fmt_in)
+ assert dstr is not dint
+
+ assert t.get_dumper([1], fmt_in) is dint
+ assert t.get_dumper([None, [1]], fmt_in) is dint
+
+ dempty = t.get_dumper([], fmt_in)
+ assert t.get_dumper([None, [None]], fmt_in) is dempty
+ assert dempty.oid == 0
+ assert dempty.dump([]) == b"{}"
+
+ L: List[List[Any]] = []
+ L.append(L)
+ with pytest.raises(psycopg.DataError):
+ assert t.get_dumper(L, fmt_in)
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == 0
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == 0
+
+
+def test_str_list_dumper_binary(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.BINARY)
+ assert isinstance(dstr, ListBinaryDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+def test_last_dumper_registered_ctx(conn):
+ cur = conn.cursor()
+
+ bd = make_bin_dumper("b")
+ cur.adapters.register_dumper(str, bd)
+ td = make_dumper("t")
+ cur.adapters.register_dumper(str, td)
+
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %t", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %b", ["hello"]).fetchone()[0] == "hellob"
+
+ cur.adapters.register_dumper(str, bd)
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellob"
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.TEXT, PyFormat.BINARY])
+def test_none_type_argument(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table none_args (id serial primary key, num integer)")
+ cur.execute("insert into none_args (num) values (%s) returning id", (None,))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as unknown oid to libpq. This is because
+ # unknown is more easily cast by postgres to different types (see jsonb
+ # later).
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ if fmt_in != PyFormat.BINARY:
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+ else:
+ # Binary types cannot be passed as unknown oids.
+ with pytest.raises(e.DatatypeMismatch):
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_no_cast_needed(conn, fmt_in):
+ # Verify that there is no need of cast in certain common scenario
+ cur = conn.execute(f"select '2021-01-01'::date + %{fmt_in.value}", [3])
+ assert cur.fetchone()[0] == dt.date(2021, 1, 4)
+
+ cur = conn.execute(f"select '[10, 20, 30]'::jsonb -> %{fmt_in.value}", [1])
+ assert cur.fetchone()[0] == 20
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
+def test_optimised_adapters():
+
+ # All the optimised adapters available
+ c_adapters = {}
+ for n in dir(_psycopg):
+ if n.startswith("_") or n in ("CDumper", "CLoader"):
+ continue
+ obj = getattr(_psycopg, n)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(
+ obj,
+ (_psycopg.CDumper, _psycopg.CLoader), # type: ignore[attr-defined]
+ ):
+ continue
+ c_adapters[n] = obj
+
+ # All the registered adapters
+ reg_adapters = set()
+ adapters = list(postgres.adapters._dumpers.values()) + postgres.adapters._loaders
+ assert len(adapters) == 5
+ for m in adapters:
+ reg_adapters |= set(m.values())
+
+ # Check that the registered adapters are the optimised one
+ i = 0
+ for cls in reg_adapters:
+ if cls.__name__ in c_adapters:
+ assert cls is c_adapters[cls.__name__]
+ i += 1
+
+ assert i >= 10
+
+ # Check that every optimised adapter is the optimised version of a Py one
+ for n in dir(psycopg.types):
+ mod = getattr(psycopg.types, n)
+ if not isinstance(mod, ModuleType):
+ continue
+ for n1 in dir(mod):
+ obj = getattr(mod, n1)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(obj, (Dumper, Loader)):
+ continue
+ c_adapters.pop(obj.__name__, None)
+
+ assert not c_adapters
+
+
+def test_dumper_init_error(conn):
+ class BadDumper(Dumper):
+ def __init__(self, cls, context):
+ super().__init__(cls, context)
+ 1 / 0
+
+ def dump(self, obj):
+ return obj.encode()
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, BadDumper)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select %s::text", ["hi"])
+
+
+def test_loader_init_error(conn):
+ class BadLoader(Loader):
+ def __init__(self, oid, context):
+ super().__init__(oid, context)
+ 1 / 0
+
+ def load(self, data):
+ return data.decode()
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", BadLoader)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 'hi'::text")
+ assert cur.fetchone() == ("hi",)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_random(conn, faker, fmt, fmt_out):
+ faker.format = fmt
+ faker.choose_schema(ncols=20)
+ faker.make_records(50)
+
+ with conn.cursor(binary=fmt_out) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+class MyStr(str):
+ pass
+
+
+def make_dumper(suffix):
+ """Create a test dumper appending a suffix to the bytes representation."""
+
+ class TestDumper(Dumper):
+ oid = TEXT_OID
+ format = pq.Format.TEXT
+
+ def dump(self, s):
+ return (s + suffix).encode("ascii")
+
+ return TestDumper
+
+
+def make_bin_dumper(suffix):
+ cls = make_dumper(suffix)
+ cls.format = pq.Format.BINARY
+ return cls
+
+
+def make_loader(suffix):
+ """Create a test loader appending a suffix to the data returned."""
+
+ class TestLoader(Loader):
+ format = pq.Format.TEXT
+
+ def load(self, b):
+ return bytes(b).decode("ascii") + suffix
+
+ return TestLoader
+
+
+def make_bin_loader(suffix):
+ cls = make_loader(suffix)
+ cls.format = pq.Format.BINARY
+ return cls
diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py
new file mode 100644
index 0000000..b355604
--- /dev/null
+++ b/tests/test_client_cursor.py
@@ -0,0 +1,855 @@
+import pickle
+import weakref
+import datetime as dt
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import sql, rows
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+
+from .utils import gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+
+
+@pytest.fixture
+def conn(conn):
+ conn.cursor_factory = psycopg.ClientCursor
+ return conn
+
+
+def test_init(conn):
+ cur = psycopg.ClientCursor(conn)
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ conn.row_factory = rows.dict_row
+ cur = psycopg.ClientCursor(conn)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_init_factory(conn):
+ cur = psycopg.ClientCursor(conn, row_factory=rows.dict_row)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_from_cursor_factory(conn_cls, dsn):
+ with conn_cls.connect(dsn, cursor_factory=psycopg.ClientCursor) as conn:
+ cur = conn.cursor()
+ assert type(cur) is psycopg.ClientCursor
+
+ cur.execute("select %s", (1,))
+ assert cur.fetchone() == (1,)
+ assert cur._query
+ assert cur._query.query == b"select 1"
+
+
+def test_close(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.execute("select 'foo'")
+
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchall()
+
+
+def test_context(conn):
+ with conn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+def test_weakref(conn):
+ cur = conn.cursor()
+ w = weakref.ref(cur)
+ cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_statusmessage(conn):
+ cur = conn.cursor()
+ assert cur.statusmessage is None
+
+ cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+def test_execute_sql(conn):
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {value}").format(value="hello"))
+ assert cur.fetchone() == ("hello",)
+
+
+def test_execute_many_results(conn):
+ cur = conn.cursor()
+ assert cur.nextset() is None
+
+ rv = cur.execute("select %s; select generate_series(1,%s)", ("foo", 3))
+ assert rv is cur
+ assert cur.fetchall() == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.nextset() is None
+
+ cur.close()
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+ cur = conn.cursor()
+ rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+def test_execute_type_change(conn):
+ # issue #112
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.execute(sql, (1,))
+ cur.execute(sql, (100_000,))
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+def test_executemany_type_change(conn):
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.executemany(sql, [(1,), (100_000,)])
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+def test_execute_copy(conn, query):
+ cur = conn.cursor()
+ cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute(query)
+
+
+def test_fetchone(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = cur.fetchone()
+ assert row == (1, "foo", None)
+ row = cur.fetchone()
+ assert row is None
+
+
+def test_binary_cursor_execute(conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None])
+
+
+def test_execute_binary(conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = conn.cursor()
+ cur.execute("select %s, %s", [1, None], binary=True)
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None], binary=False)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_query_encode(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("select '\u20ac'").fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_query_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+def test_executemany_no_data(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+def test_executemany_rowcount(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+def test_executemany_returning(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_returning_discard(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+def test_executemany_no_result(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+def test_executemany_rowcount_no_hit(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ # This fails, but only because we try to copy in pipeline mode,
+ # crashing the connection. Which would be even fine, but with
+ # the async cursor it's worse... See test_client_cursor_async.py.
+ # "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_executemany_null_first(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table testmany (a bigint, b bigint)")
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+def test_rowcount(conn):
+ cur = conn.cursor()
+
+ cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+
+def test_rownumber(conn):
+ cur = conn.cursor()
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+def test_iter(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ for rec in cur:
+ assert rec == (1,)
+ break
+
+ for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert cur.fetchone() == (3,)
+ assert list(cur) == []
+
+
+def test_row_factory(conn):
+ cur = conn.cursor(row_factory=my_row_factory)
+
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+ cur.execute("select 'foo' as bar")
+ (r,) = cur.fetchone()
+ assert r == "FOObar"
+
+ cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert cur.fetchall() == [["Yy", "Zz"]]
+
+ cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"y": "y", "z": "z"}
+
+
+def test_row_factory_none(conn):
+ cur = conn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ r = cur.execute("select 1 as a, 2 as b").fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+def test_bad_row_factory(conn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = conn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = conn.cursor(row_factory=broken_maker)
+ cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ cur.fetchone()
+
+
+def test_scroll(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.scroll(0)
+
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(-1)
+ assert cur.fetchone() == (8,)
+ cur.scroll(-2)
+ assert cur.fetchone() == (7,)
+ cur.scroll(2, mode="absolute")
+ assert cur.fetchone() == (2,)
+
+ # on the boundary
+ cur.scroll(0, mode="absolute")
+ assert cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ cur.scroll(-1, mode="absolute")
+
+ cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(-1)
+
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ cur.scroll(10, mode="absolute")
+
+ cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ cur.scroll(1, "wat")
+
+
+def test_query_params_execute(conn):
+ cur = conn.cursor()
+ assert cur._query is None
+
+ cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select 1, NULL::text"
+ assert cur._query.params == (b"1", b"NULL")
+
+ cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select 'wat'::int"
+ assert cur._query.params == (b"'wat'",)
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select %(x)s", {"x": 1}, (1,)),
+ ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)),
+ ("select %(x)s, %(x)s", {"x": 1}, (1, 1)),
+ ],
+)
+def test_query_params_named(conn, query, params, want):
+ cur = conn.cursor()
+ cur.execute(query, params)
+ rec = cur.fetchone()
+ assert rec == want
+
+
+def test_query_params_executemany(conn):
+ cur = conn.cursor()
+
+ cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select 3, 4"
+ assert cur._query.params == (b"3", b"4")
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_stream(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+class TestColumn:
+ def test_description_attribs(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ assert len(curs.description) == 3
+ for c in curs.description:
+ len(c) == 7 # DBAPI happy
+ for i, a in enumerate(
+ """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ ):
+ assert c[i] == getattr(c, a)
+
+ # Won't fill them up
+ assert c.null_ok is None
+
+ c = curs.description[0]
+ assert c.name == "pi"
+ assert c.type_code == builtins["numeric"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision == 10
+ assert c.scale == 2
+
+ c = curs.description[1]
+ assert c.name == "hi"
+ assert c.type_code == builtins["text"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision is None
+ assert c.scale is None
+
+ c = curs.description[2]
+ assert c.name == "now"
+ assert c.type_code == builtins["date"].oid
+ assert c.display_size is None
+ if is_crdb(conn):
+ assert c.internal_size == 16
+ else:
+ assert c.internal_size == 4
+ assert c.precision is None
+ assert c.scale is None
+
+ def test_description_slice(self, conn):
+ curs = conn.cursor()
+ curs.execute("select 1::int as a")
+ curs.description[0][0:2] == ("a", 23)
+
+ @pytest.mark.parametrize(
+ "type, precision, scale, dsize, isize",
+ [
+ ("text", None, None, None, None),
+ ("varchar", None, None, None, None),
+ ("varchar(42)", None, None, 42, None),
+ ("int4", None, None, None, 4),
+ ("numeric", None, None, None, None),
+ ("numeric(10)", 10, 0, None, None),
+ ("numeric(10, 3)", 10, 3, None, None),
+ ("time", None, None, None, 8),
+ crdb_time_precision("time(4)", 4, None, None, 8),
+ crdb_time_precision("time(10)", 6, None, None, 8),
+ ],
+ )
+ def test_details(self, conn, type, precision, scale, dsize, isize):
+ cur = conn.cursor()
+ cur.execute(f"select null::{type}")
+ col = cur.description[0]
+ repr(col)
+ assert col.precision == precision
+ assert col.scale == scale
+ assert col.display_size == dsize
+ assert col.internal_size == isize
+
+ def test_pickle(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ description = curs.description
+ pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+ unpickled = pickle.loads(pickled)
+ assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+ @pytest.mark.crdb_skip("no col query")
+ def test_no_col_query(self, conn):
+ cur = conn.execute("select")
+ assert cur.description == []
+ assert cur.fetchall() == [()]
+
+ def test_description_closed_connection(self, conn):
+ # If we have reasons to break this test we will (e.g. we really need
+ # the connection). In #172 it fails just by accident.
+ cur = conn.execute("select 1::int4 as foo")
+ conn.close()
+ assert len(cur.description) == 1
+ col = cur.description[0]
+ assert col.name == "foo"
+ assert col.type_code == 23
+
+ def test_name_not_a_name(self, conn):
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "foo-bar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_name_encode(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "\u20ac"
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ assert "psycopg.ClientCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_leak(conn_cls, dsn, faker, fetch, row_factory):
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ def work():
+ with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True):
+ with psycopg.ClientCursor(conn, row_factory=row_factory) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ cur.fetchall()
+ elif fetch == "iter":
+ for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select 'hello'", (), "select 'hello'"),
+ ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
+ ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
+ ("select %%", (), "select %%"),
+ ("select %%, %s", (["a"],), "select %, 'a'"),
+ ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
+ ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
+ ],
+)
+def test_mogrify(conn, query, params, want):
+ cur = conn.cursor()
+ got = cur.mogrify(query, *params)
+ assert got == want
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_mogrify_encoding(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ q = conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+ assert q == "select '\u20ac'"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_mogrify_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ with pytest.raises(UnicodeEncodeError):
+ conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+
+
+@pytest.mark.pipeline
+def test_message_0x33(conn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ conn.autocommit = True
+ with conn.pipeline():
+ cur = conn.execute("select 'test'")
+ assert cur.fetchone() == ("test",)
+
+ assert not notices
diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py
new file mode 100644
index 0000000..0cf8ec6
--- /dev/null
+++ b/tests/test_client_cursor_async.py
@@ -0,0 +1,727 @@
+import pytest
+import weakref
+import datetime as dt
+from typing import List
+
+import psycopg
+from psycopg import sql, rows
+from psycopg.adapt import PyFormat
+
+from .utils import alist, gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .test_cursor import execmany, _execmany # noqa: F401
+from .fix_crdb import crdb_encoding
+
+execmany = execmany # avoid F811 underneath
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def aconn(aconn):
+ aconn.cursor_factory = psycopg.AsyncClientCursor
+ return aconn
+
+
+async def test_init(aconn):
+ cur = psycopg.AsyncClientCursor(aconn)
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+ aconn.row_factory = rows.dict_row
+ cur = psycopg.AsyncClientCursor(aconn)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_init_factory(aconn):
+ cur = psycopg.AsyncClientCursor(aconn, row_factory=rows.dict_row)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_from_cursor_factory(aconn_cls, dsn):
+ async with await aconn_cls.connect(
+ dsn, cursor_factory=psycopg.AsyncClientCursor
+ ) as aconn:
+ cur = aconn.cursor()
+ assert type(cur) is psycopg.AsyncClientCursor
+
+ await cur.execute("select %s", (1,))
+ assert await cur.fetchone() == (1,)
+ assert cur._query
+ assert cur._query.query == b"select 1"
+
+
+async def test_close(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.execute("select 'foo'")
+
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_context(aconn):
+ async with aconn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+async def test_weakref(aconn):
+ cur = aconn.cursor()
+ w = weakref.ref(cur)
+ await cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_statusmessage(aconn):
+ cur = aconn.cursor()
+ assert cur.statusmessage is None
+
+ await cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ await cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+async def test_execute_sql(aconn):
+ cur = aconn.cursor()
+ await cur.execute(sql.SQL("select {value}").format(value="hello"))
+ assert await cur.fetchone() == ("hello",)
+
+
+async def test_execute_many_results(aconn):
+ cur = aconn.cursor()
+ assert cur.nextset() is None
+
+ rv = await cur.execute("select %s; select generate_series(1,%s)", ("foo", 3))
+ assert rv is cur
+ assert (await cur.fetchall()) == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert (await cur.fetchall()) == [(1,), (2,), (3,)]
+ assert cur.rowcount == 3
+ assert cur.nextset() is None
+
+ await cur.close()
+ assert cur.nextset() is None
+
+
+async def test_execute_sequence(aconn):
+ cur = aconn.cursor()
+ rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+async def test_execute_empty_query(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_execute_type_change(aconn):
+ # issue #112
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.execute(sql, (1,))
+ await cur.execute(sql, (100_000,))
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+async def test_executemany_type_change(aconn):
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.executemany(sql, [(1,), (100_000,)])
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+async def test_execute_copy(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute(query)
+
+
+async def test_fetchone(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = await cur.fetchone()
+ assert row == (1, "foo", None)
+ row = await cur.fetchone()
+ assert row is None
+
+
+async def test_binary_cursor_execute(aconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None])
+
+
+async def test_execute_binary(aconn):
+ with pytest.raises(psycopg.NotSupportedError):
+ cur = aconn.cursor()
+ await cur.execute("select %s, %s", [1, None], binary=True)
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None], binary=False)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_query_encode(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ await cur.execute("select '\u20ac'")
+ (res,) = await cur.fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_query_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ await cur.execute("select '\u20ac'")
+
+
+async def test_executemany(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(10, "hello"), (20, "world")]
+
+
+async def test_executemany_name(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(11, "hello"), (21, "world")]
+
+
+async def test_executemany_no_data(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+async def test_executemany_rowcount(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+async def test_executemany_returning(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_returning_discard(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_result(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+async def test_executemany_rowcount_no_hit(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ await cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ await cur.executemany(
+ "delete from execmany where id = %s returning num", [(-1,), (-2,)]
+ )
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ # This fails because we end up trying to copy in pipeline mode.
+ # However, sometimes (and pretty regularly if we enable pgconn.trace())
+ # something goes in a loop and only terminates by OOM. Strace shows
+ # an allocation loop. I think it's in the libpq.
+ # "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+async def test_executemany_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+async def test_executemany_null_first(aconn, fmt_in):
+ cur = aconn.cursor()
+ await cur.execute("create table testmany (a bigint, b bigint)")
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+async def test_rowcount(aconn):
+ cur = aconn.cursor()
+
+ await cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ await cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ await cur.execute(
+ "insert into test_rowcount_notuples select generate_series(1, 42)"
+ )
+ assert cur.rowcount == 42
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor()
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ async for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(await cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+async def test_iter(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ res = []
+ async for rec in cur:
+ res.append(rec)
+ assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ async for rec in cur:
+ assert rec == (1,)
+ break
+
+ async for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert (await cur.fetchone()) == (3,)
+ async for rec in cur:
+ assert False
+
+
+async def test_row_factory(aconn):
+ cur = aconn.cursor(row_factory=my_row_factory)
+ await cur.execute("select 'foo' as bar")
+ (r,) = await cur.fetchone()
+ assert r == "FOObar"
+
+ await cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert await cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert await cur.fetchall() == [["Yy", "Zz"]]
+
+ await cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"y": "y", "z": "z"}
+
+
+async def test_row_factory_none(aconn):
+ cur = aconn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ await cur.execute("select 1 as a, 2 as b")
+ r = await cur.fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+async def test_bad_row_factory(aconn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = aconn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ await cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = aconn.cursor(row_factory=broken_maker)
+ await cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ await cur.fetchone()
+
+
+async def test_scroll(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.scroll(0)
+
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-1)
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-2)
+ assert await cur.fetchone() == (7,)
+ await cur.scroll(2, mode="absolute")
+ assert await cur.fetchone() == (2,)
+
+ # on the boundary
+ await cur.scroll(0, mode="absolute")
+ assert await cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ await cur.scroll(-1, mode="absolute")
+
+ await cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(-1)
+
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ await cur.scroll(10, mode="absolute")
+
+ await cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(1, "wat")
+
+
+async def test_query_params_execute(aconn):
+ cur = aconn.cursor()
+ assert cur._query is None
+
+ await cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select 1, NULL::text"
+ assert cur._query.params == (b"1", b"NULL")
+
+ await cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ await cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select 'wat'::int"
+ assert cur._query.params == (b"'wat'",)
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select %(x)s", {"x": 1}, (1,)),
+ ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)),
+ ("select %(x)s, %(x)s", {"x": 1}, (1, 1)),
+ ],
+)
+async def test_query_params_named(aconn, query, params, want):
+ cur = aconn.cursor()
+ await cur.execute(query, params)
+ rec = await cur.fetchone()
+ assert rec == want
+
+
+async def test_query_params_executemany(aconn):
+ cur = aconn.cursor()
+
+ await cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select 3, 4"
+ assert cur._query.params == (b"3", b"4")
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_stream(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ assert "psycopg.AsyncClientCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ await cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ await cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_leak(aconn_cls, dsn, faker, fetch, row_factory):
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn, conn.transaction(
+ force_rollback=True
+ ):
+ async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+ await cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = await cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = await cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ await cur.fetchall()
+ elif fetch == "iter":
+ async for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ ("select 'hello'", (), "select 'hello'"),
+ ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
+ ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
+ ("select %%", (), "select %%"),
+ ("select %%, %s", (["a"],), "select %, 'a'"),
+ ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
+ ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
+ ],
+)
+async def test_mogrify(aconn, query, params, want):
+ cur = aconn.cursor()
+ got = cur.mogrify(query, *params)
+ assert got == want
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_mogrify_encoding(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ q = aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+ assert q == "select '\u20ac'"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_mogrify_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ with pytest.raises(UnicodeEncodeError):
+ aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"})
+
+
+@pytest.mark.pipeline
+async def test_message_0x33(aconn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ cur = await aconn.execute("select 'test'")
+ assert (await cur.fetchone()) == ("test",)
+
+ assert not notices
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 0000000..eec24f1
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,327 @@
+"""
+Tests dealing with concurrency issues.
+"""
+
+import os
+import sys
+import time
+import queue
+import signal
+import threading
+import multiprocessing
+import subprocess as sp
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import errors as e
+
+
+@pytest.mark.slow
+def test_concurrent_execution(conn_cls, dsn):
+ def worker():
+ cnn = conn_cls.connect(dsn)
+ cur = cnn.cursor()
+ cur.execute("select pg_sleep(0.5)")
+ cur.close()
+ cnn.close()
+
+ t1 = threading.Thread(target=worker)
+ t2 = threading.Thread(target=worker)
+ t0 = time.time()
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+ assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+def test_commit_concurrency(conn):
+ # Check the condition reported in psycopg2#103
+ # Because of bad status check, we commit even when a commit is already on
+ # its way. We can detect this condition by the warnings.
+ notices = queue.Queue() # type: ignore[var-annotated]
+ conn.add_notice_handler(lambda diag: notices.put(diag.message_primary))
+ stop = False
+
+ def committer():
+ nonlocal stop
+ while not stop:
+ conn.commit()
+
+ cur = conn.cursor()
+ t1 = threading.Thread(target=committer)
+ t1.start()
+ for i in range(1000):
+ cur.execute("select %s;", (i,))
+ conn.commit()
+
+ # Stop the committer thread
+ stop = True
+ t1.join()
+
+ assert notices.empty(), "%d notices raised" % notices.qsize()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+def test_multiprocess_close(dsn, tmpdir):
+ # Check the problem reported in psycopg2#829
+ # Subprocess gcs the copy of the fd after fork so it closes connection.
+ module = f"""\
+import time
+import psycopg
+
+def thread():
+ conn = psycopg.connect({dsn!r})
+ curs = conn.cursor()
+ for i in range(10):
+ curs.execute("select 1")
+ time.sleep(0.1)
+
+def process():
+ time.sleep(0.2)
+"""
+
+ script = """\
+import time
+import threading
+import multiprocessing
+import mptest
+
+t = threading.Thread(target=mptest.thread, name='mythread')
+t.start()
+time.sleep(0.2)
+multiprocessing.Process(target=mptest.process, name='myprocess').start()
+t.join()
+"""
+
+ with (tmpdir / "mptest.py").open("w") as f:
+ f.write(module)
+ env = dict(os.environ)
+ env["PYTHONPATH"] = str(tmpdir + os.pathsep + env.get("PYTHONPATH", ""))
+ out = sp.check_output(
+ [sys.executable, "-c", script], stderr=sp.STDOUT, env=env
+ ).decode("utf8", "replace")
+ assert out == "", out.strip().splitlines()[-1]
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("notify")
+def test_notifies(conn_cls, conn, dsn):
+ nconn = conn_cls.connect(dsn, autocommit=True)
+ npid = nconn.pgconn.backend_pid
+
+ def notifier():
+ time.sleep(0.25)
+ nconn.cursor().execute("notify foo, '1'")
+ time.sleep(0.25)
+ nconn.cursor().execute("notify foo, '2'")
+ nconn.close()
+
+ conn.autocommit = True
+ conn.cursor().execute("listen foo")
+
+ t0 = time.time()
+ t = threading.Thread(target=notifier)
+ t.start()
+
+ ns = []
+ gen = conn.notifies()
+ for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ gen.close()
+
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert isinstance(n, psycopg.Notify)
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+ t.join()
+
+
+def canceller(conn, errors):
+ try:
+ time.sleep(0.5)
+ conn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+def test_cancel(conn):
+ errors: List[Exception] = []
+
+ cur = conn.cursor()
+ t = threading.Thread(target=canceller, args=(conn, errors))
+ t0 = time.time()
+ t.start()
+
+ with pytest.raises(e.QueryCanceled):
+ cur.execute("select pg_sleep(2)")
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ conn.rollback()
+ assert cur.execute("select 1").fetchone()[0] == 1
+
+ t.join()
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+def test_cancel_stream(conn):
+ errors: List[Exception] = []
+
+ cur = conn.cursor()
+ t = threading.Thread(target=canceller, args=(conn, errors))
+ t0 = time.time()
+ t.start()
+
+ with pytest.raises(e.QueryCanceled):
+ for row in cur.stream("select pg_sleep(2)"):
+ pass
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ conn.rollback()
+ assert cur.execute("select 1").fetchone()[0] == 1
+
+ t.join()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+@pytest.mark.slow
+def test_identify_closure(conn_cls, dsn):
+ def closer():
+ time.sleep(0.2)
+ conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
+
+ conn = conn_cls.connect(dsn)
+ conn2 = conn_cls.connect(dsn)
+ try:
+ t = threading.Thread(target=closer)
+ t.start()
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_sleep(1.0)")
+ t1 = time.time()
+ assert 0.2 < t1 - t0 < 0.4
+ finally:
+ t.join()
+ finally:
+ conn.close()
+ conn2.close()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb_skip("cancel")
+def test_ctrl_c(dsn):
+ if sys.platform == "win32":
+ sig = int(signal.CTRL_C_EVENT)
+ # Or pytest will receive the Ctrl-C too
+ creationflags = sp.CREATE_NEW_PROCESS_GROUP
+ else:
+ sig = int(signal.SIGINT)
+ creationflags = 0
+
+ script = f"""\
+import os
+import time
+import psycopg
+from threading import Thread
+
+def tired_of_life():
+ time.sleep(1)
+ os.kill(os.getpid(), {sig!r})
+
+t = Thread(target=tired_of_life, daemon=True)
+t.start()
+
+with psycopg.connect({dsn!r}) as conn:
+ cur = conn.cursor()
+ ctrl_c = False
+ try:
+ cur.execute("select pg_sleep(2)")
+ except KeyboardInterrupt:
+ ctrl_c = True
+
+ assert ctrl_c, "ctrl-c not received"
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ conn.rollback()
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+"""
+ t0 = time.time()
+ proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags)
+ proc.communicate()
+ t = time.time() - t0
+ assert proc.returncode == 0
+ assert 1 < t < 2
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ multiprocessing.get_all_start_methods()[0] != "fork",
+ reason="problematic behavior only exhibited via fork",
+)
+def test_segfault_on_fork_close(dsn):
+ # https://github.com/psycopg/psycopg/issues/300
+ script = f"""\
+import gc
+import psycopg
+from multiprocessing import Pool
+
+def test(arg):
+ conn1 = psycopg.connect({dsn!r})
+ conn1.close()
+ conn1 = None
+ gc.collect()
+ return 1
+
+if __name__ == '__main__':
+ conn = psycopg.connect({dsn!r})
+ with Pool(2) as p:
+ pool_result = p.map_async(test, [1, 2])
+ pool_result.wait(timeout=5)
+ if pool_result.ready():
+ print(pool_result.get(timeout=1))
+"""
+ env = dict(os.environ)
+ env["PYTHONFAULTHANDLER"] = "1"
+ out = sp.check_output([sys.executable, "-s", "-c", script], env=env)
+ assert out.decode().rstrip() == "[1, 1]"
diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py
new file mode 100644
index 0000000..29b08cf
--- /dev/null
+++ b/tests/test_concurrency_async.py
@@ -0,0 +1,242 @@
+import sys
+import time
+import signal
+import asyncio
+import subprocess as sp
+from asyncio.queues import Queue
+from typing import List, Tuple
+
+import pytest
+
+import psycopg
+from psycopg import errors as e
+from psycopg._compat import create_task
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.slow
+async def test_commit_concurrency(aconn):
+ # Check the condition reported in psycopg2#103
+ # Because of bad status check, we commit even when a commit is already on
+ # its way. We can detect this condition by the warnings.
+ notices = Queue() # type: ignore[var-annotated]
+ aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary))
+ stop = False
+
+ async def committer():
+ nonlocal stop
+ while not stop:
+ await aconn.commit()
+ await asyncio.sleep(0) # Allow the other worker to work
+
+ async def runner():
+ nonlocal stop
+ cur = aconn.cursor()
+ for i in range(1000):
+ await cur.execute("select %s;", (i,))
+ await aconn.commit()
+
+ # Stop the committer thread
+ stop = True
+
+ await asyncio.gather(committer(), runner())
+ assert notices.empty(), "%d notices raised" % notices.qsize()
+
+
+@pytest.mark.slow
+async def test_concurrent_execution(aconn_cls, dsn):
+ async def worker():
+ cnn = await aconn_cls.connect(dsn)
+ cur = cnn.cursor()
+ await cur.execute("select pg_sleep(0.5)")
+ await cur.close()
+ await cnn.close()
+
+ workers = [worker(), worker()]
+ t0 = time.time()
+ await asyncio.gather(*workers)
+ assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("notify")
+async def test_notifies(aconn_cls, aconn, dsn):
+ nconn = await aconn_cls.connect(dsn, autocommit=True)
+ npid = nconn.pgconn.backend_pid
+
+ async def notifier():
+ cur = nconn.cursor()
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '1'")
+ await asyncio.sleep(0.25)
+ await cur.execute("notify foo, '2'")
+ await nconn.close()
+
+ async def receiver():
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ gen = aconn.notifies()
+ async for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ await gen.aclose()
+
+ ns: List[Tuple[psycopg.Notify, float]] = []
+ t0 = time.time()
+ workers = [notifier(), receiver()]
+ await asyncio.gather(*workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+async def canceller(aconn, errors):
+ try:
+ await asyncio.sleep(0.5)
+ aconn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ await cur.execute("select pg_sleep(2)")
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("cancel")
+async def test_cancel_stream(aconn):
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(e.QueryCanceled):
+ async for row in cur.stream("select pg_sleep(2)"):
+ pass
+
+ errors: List[Exception] = []
+ workers = [worker(), canceller(aconn, errors)]
+
+ t0 = time.time()
+ await asyncio.gather(*workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+@pytest.mark.slow
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_identify_closure(aconn_cls, dsn):
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute(
+ "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid]
+ )
+
+ aconn = await aconn_cls.connect(dsn)
+ conn2 = await aconn_cls.connect(dsn)
+ try:
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("select pg_sleep(1.0)")
+ t1 = time.time()
+ assert 0.2 < t1 - t0 < 0.4
+ finally:
+ await asyncio.gather(t)
+ finally:
+ await aconn.close()
+ await conn2.close()
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb_skip("cancel")
+async def test_ctrl_c(dsn):
+ script = f"""\
+import signal
+import asyncio
+import psycopg
+
+async def main():
+ ctrl_c = False
+ loop = asyncio.get_event_loop()
+ async with await psycopg.AsyncConnection.connect({dsn!r}) as conn:
+ loop.add_signal_handler(signal.SIGINT, conn.cancel)
+ cur = conn.cursor()
+ try:
+ await cur.execute("select pg_sleep(2)")
+ except psycopg.errors.QueryCanceled:
+ ctrl_c = True
+
+ assert ctrl_c, "ctrl-c not received"
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await conn.rollback()
+ assert (
+ conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE
+ ), f"transaction status: {{conn.info.transaction_status!r}}"
+
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+asyncio.run(main())
+"""
+ if sys.platform == "win32":
+ creationflags = sp.CREATE_NEW_PROCESS_GROUP
+ sig = signal.CTRL_C_EVENT
+ else:
+ creationflags = 0
+ sig = signal.SIGINT
+
+ proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags)
+ with pytest.raises(sp.TimeoutExpired):
+ outs, errs = proc.communicate(timeout=1)
+
+ proc.send_signal(sig)
+ proc.communicate()
+ assert proc.returncode == 0
diff --git a/tests/test_connection.py b/tests/test_connection.py
new file mode 100644
index 0000000..57c6c78
--- /dev/null
+++ b/tests/test_connection.py
@@ -0,0 +1,790 @@
+import time
+import pytest
+import logging
+import weakref
+from typing import Any, List
+from dataclasses import dataclass
+
+import psycopg
+from psycopg import Notify, errors as e
+from psycopg.rows import tuple_row
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
+
+from .utils import gc_collect
+from .test_cursor import my_row_factory
+from .test_adapt import make_bin_dumper, make_dumper
+
+
+def test_connect(conn_cls, dsn):
+ conn = conn_cls.connect(dsn)
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ conn.close()
+
+
+def test_connect_str_subclass(conn_cls, dsn):
+ class MyString(str):
+ pass
+
+ conn = conn_cls.connect(MyString(dsn))
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ conn.close()
+
+
+def test_connect_bad(conn_cls):
+ with pytest.raises(psycopg.OperationalError):
+ conn_cls.connect("dbname=nosuchdb")
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_connect_timeout(conn_cls, deaf_port):
+ t0 = time.time()
+ with pytest.raises(psycopg.OperationalError, match="timeout expired"):
+ conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+ elapsed = time.time() - t0
+ assert elapsed == pytest.approx(1.0, abs=0.05)
+
+
+def test_close(conn):
+ assert not conn.closed
+ assert not conn.broken
+
+ cur = conn.cursor()
+
+ conn.close()
+ assert conn.closed
+ assert not conn.broken
+ assert conn.pgconn.status == conn.ConnStatus.BAD
+
+ conn.close()
+ assert conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.BAD
+
+ with pytest.raises(psycopg.OperationalError):
+ cur.execute("select 1")
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_broken(conn):
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
+ assert conn.closed
+ assert conn.broken
+ conn.close()
+ assert conn.closed
+ assert conn.broken
+
+
+def test_cursor_closed(conn):
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ with conn.cursor("foo"):
+ pass
+ with pytest.raises(psycopg.OperationalError):
+ conn.cursor()
+
+
+def test_connection_warn_close(conn_cls, dsn, recwarn):
+ conn = conn_cls.connect(dsn)
+ conn.close()
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+ conn = conn_cls.connect(dsn)
+ del conn
+ assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = conn_cls.connect(dsn)
+ conn.execute("select 1")
+ del conn
+ assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = conn_cls.connect(dsn)
+ try:
+ conn.execute("select wat")
+ except Exception:
+ pass
+ del conn
+ assert "INERROR" in str(recwarn.pop(ResourceWarning).message)
+
+ with conn_cls.connect(dsn) as conn:
+ pass
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.fixture
+def testctx(svcconn):
+ svcconn.execute("create table if not exists testctx (id int primary key)")
+ svcconn.execute("delete from testctx")
+ return None
+
+
+def test_context_commit(conn_cls, testctx, conn, dsn):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("insert into testctx values (42)")
+
+ assert conn.closed
+ assert not conn.broken
+
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ cur.execute("select * from testctx")
+ assert cur.fetchall() == [(42,)]
+
+
+def test_context_rollback(conn_cls, testctx, conn, dsn):
+ with pytest.raises(ZeroDivisionError):
+ with conn:
+ with conn.cursor() as cur:
+ cur.execute("insert into testctx values (42)")
+ 1 / 0
+
+ assert conn.closed
+ assert not conn.broken
+
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor() as cur:
+ cur.execute("select * from testctx")
+ assert cur.fetchall() == []
+
+
+def test_context_close(conn):
+ with conn:
+ conn.execute("select 1")
+ conn.close()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_context_inerror_rollback_no_clobber(conn_cls, conn, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn2:
+ conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn:
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ assert not conn.pgconn.error_message
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.slow
+def test_weakref(conn_cls, dsn):
+ conn = conn_cls.connect(dsn)
+ w = weakref.ref(conn)
+ conn.close()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+def test_commit(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.commit()
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_commit_error(conn):
+ conn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ conn.commit()
+
+ conn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+
+def test_rollback(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.rollback()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.ntuples == 0
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.rollback()
+
+
+def test_auto_transaction(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("meh")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+ with pytest.raises(psycopg.errors.InFailedSqlTransaction):
+ cur.execute("select 1")
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() is None
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_autocommit(conn):
+ assert conn.autocommit is False
+ conn.autocommit = True
+ assert conn.autocommit
+ cur = conn.cursor()
+ assert cur.execute("select 1").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ conn.autocommit = ""
+ assert conn.autocommit is False # type: ignore[comparison-overlap]
+ conn.autocommit = "yeah"
+ assert conn.autocommit is True
+
+
+def test_autocommit_connect(conn_cls, dsn):
+ conn = conn_cls.connect(dsn, autocommit=True)
+ assert conn.autocommit
+ conn.close()
+
+
+def test_autocommit_intrans(conn):
+ cur = conn.cursor()
+ assert cur.execute("select 1").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+def test_autocommit_inerror(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("meh")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+def test_autocommit_unknown(conn):
+ conn.close()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN
+ with pytest.raises(psycopg.OperationalError):
+ conn.autocommit = True
+ assert not conn.autocommit
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("host=foo user=bar",), {}, "host=foo user=bar"),
+ (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (
+ ("host=foo port=5432",),
+ {"host": "qux", "user": "joe"},
+ "host=qux user=joe port=5432",
+ ),
+ (("host=foo",), {"user": None}, "host=foo"),
+ ],
+)
+def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = conn_cls.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ with pytest.raises(exctype):
+ conn_cls.connect(*args, **kwargs)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_broken_connection(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("select pg_terminate_backend(pg_backend_pid())")
+ assert conn.closed
+
+
+@pytest.mark.crdb_skip("do")
+def test_notice_handlers(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ messages = []
+ severities = []
+
+ def cb1(diag):
+ messages.append(diag.message_primary)
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ conn.add_notice_handler(cb1)
+ conn.add_notice_handler(cb2)
+ conn.add_notice_handler("the wrong thing")
+ conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
+
+ conn.pgconn.exec_(b"set client_min_messages to notice")
+ cur = conn.cursor()
+ cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ conn.remove_notice_handler(cb1)
+ conn.remove_notice_handler("the wrong thing")
+ cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql")
+ assert len(caplog.records) == 3
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE", "WARNING"]
+
+ with pytest.raises(ValueError):
+ conn.remove_notice_handler(cb1)
+
+
+@pytest.mark.crdb_skip("notify")
+def test_notify_handlers(conn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ conn.add_notify_handler(cb1)
+ conn.add_notify_handler(lambda n: nots2.append(n))
+
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("listen foo")
+ cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == conn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ conn.remove_notify_handler(cb1)
+ cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == conn.pgconn.backend_pid
+ assert hash(n)
+
+ with pytest.raises(ValueError):
+ conn.remove_notify_handler(cb1)
+
+
+def test_execute(conn):
+ cur = conn.execute("select %s, %s", [10, 20])
+ assert cur.fetchone() == (10, 20)
+ assert cur.format == 0
+ assert cur.pgresult.fformat(0) == 0
+
+ cur = conn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21})
+ assert cur.fetchone() == (11, 21)
+
+ cur = conn.execute("select 12, 22")
+ assert cur.fetchone() == (12, 22)
+
+
+def test_execute_binary(conn):
+ cur = conn.execute("select %s, %s", [10, 20], binary=True)
+ assert cur.fetchone() == (10, 20)
+ assert cur.format == 1
+ assert cur.pgresult.fformat(0) == 1
+
+
+def test_row_factory(conn_cls, dsn):
+ defaultconn = conn_cls.connect(dsn)
+ assert defaultconn.row_factory is tuple_row
+ defaultconn.close()
+
+ conn = conn_cls.connect(dsn, row_factory=my_row_factory)
+ assert conn.row_factory is my_row_factory
+
+ cur = conn.execute("select 'a' as ve")
+ assert cur.fetchone() == ["Ave"]
+
+ with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1:
+ cur1.execute("select 1, 1, 2")
+ assert cur1.fetchall() == [{1, 2}]
+
+ with conn.cursor(row_factory=tuple_row) as cur2:
+ cur2.execute("select 1, 1, 2")
+ assert cur2.fetchall() == [(1, 1, 2)]
+
+ # TODO: maybe fix something to get rid of 'type: ignore' below.
+ conn.row_factory = tuple_row
+ cur3 = conn.execute("select 'vale'")
+ r = cur3.fetchone()
+ assert r and r == ("vale",)
+ conn.close()
+
+
+def test_str(conn):
+ assert "[IDLE]" in str(conn)
+ conn.close()
+ assert "[BAD]" in str(conn)
+
+
+def test_fileno(conn):
+ assert conn.fileno() == conn.pgconn.socket
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.fileno()
+
+
+def test_cursor_factory(conn):
+ assert conn.cursor_factory is psycopg.Cursor
+
+ class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
+ pass
+
+ conn.cursor_factory = MyCursor
+ with conn.cursor() as cur:
+ assert isinstance(cur, MyCursor)
+
+ with conn.execute("select 1") as cur:
+ assert isinstance(cur, MyCursor)
+
+
+def test_cursor_factory_connect(conn_cls, dsn):
+ class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
+ pass
+
+ with conn_cls.connect(dsn, cursor_factory=MyCursor) as conn:
+ assert conn.cursor_factory is MyCursor
+ cur = conn.cursor()
+ assert type(cur) is MyCursor
+
+
+def test_server_cursor_factory(conn):
+ assert conn.server_cursor_factory is psycopg.ServerCursor
+
+ class MyServerCursor(psycopg.ServerCursor[psycopg.rows.Row]):
+ pass
+
+ conn.server_cursor_factory = MyServerCursor
+ with conn.cursor(name="n") as cur:
+ assert isinstance(cur, MyServerCursor)
+
+
+@dataclass
+class ParamDef:
+ name: str
+ guc: str
+ values: List[Any]
+
+
+param_isolation = ParamDef(
+ name="isolation_level",
+ guc="isolation",
+ values=list(psycopg.IsolationLevel),
+)
+param_read_only = ParamDef(
+ name="read_only",
+ guc="read_only",
+ values=[True, False],
+)
+param_deferrable = ParamDef(
+ name="deferrable",
+ guc="deferrable",
+ values=[True, False],
+)
+
+# Map Python values to Postgres values for the tx_params possible values
+tx_values_map = {
+ v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel
+}
+tx_values_map["on"] = True
+tx_values_map["off"] = False
+
+
+tx_params = [
+ param_isolation,
+ param_read_only,
+ pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")),
+]
+tx_params_isolation = [
+ pytest.param(
+ param_isolation,
+ id="isolation_level",
+ marks=pytest.mark.crdb("skip", reason="transaction isolation"),
+ ),
+ pytest.param(
+ param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only")
+ ),
+ pytest.param(
+ param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable")
+ ),
+]
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_transaction_param_default(conn, param):
+ assert getattr(conn, param.name) is None
+ current, default = conn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ ).fetchone()
+ assert current == default
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+def test_set_transaction_param_implicit(conn, param, autocommit):
+ conn.autocommit = autocommit
+ for value in param.values:
+ setattr(conn, param.name, value)
+ pgval, default = conn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ ).fetchone()
+ if autocommit:
+ assert pgval == default
+ else:
+ assert tx_values_map[pgval] == value
+ conn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+def test_set_transaction_param_block(conn, param, autocommit):
+ conn.autocommit = autocommit
+ for value in param.values:
+ setattr(conn, param.name, value)
+ with conn.transaction():
+ pgval = conn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ ).fetchone()[0]
+ assert tx_values_map[pgval] == value
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_implicit(conn, param):
+ conn.execute("select 1")
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_block(conn, param):
+ with conn.transaction():
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.parametrize("param", tx_params)
+def test_set_transaction_param_not_intrans_external(conn, param):
+ conn.autocommit = True
+ conn.execute("begin")
+ with pytest.raises(psycopg.ProgrammingError):
+ setattr(conn, param.name, param.values[0])
+
+
+@pytest.mark.crdb("skip", reason="transaction isolation")
+def test_set_transaction_param_all(conn):
+ params: List[Any] = tx_params[:]
+ params[2] = params[2].values[0]
+
+ for param in params:
+ value = param.values[0]
+ setattr(conn, param.name, value)
+
+ for param in params:
+ pgval = conn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ ).fetchone()[0]
+ assert tx_values_map[pgval] == value
+
+
+def test_set_transaction_param_strange(conn):
+ for val in ("asdf", 0, 5):
+ with pytest.raises(ValueError):
+ conn.isolation_level = val
+
+ conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE.value
+ assert conn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE
+
+ conn.read_only = 1
+ assert conn.read_only is True
+
+ conn.deferrable = 0
+ assert conn.deferrable is False
+
+
+conninfo_params_timeout = [
+ (
+ "",
+ {"dbname": "mydb", "connect_timeout": None},
+ ({"dbname": "mydb"}, None),
+ ),
+ (
+ "",
+ {"dbname": "mydb", "connect_timeout": 1},
+ ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+ ),
+ (
+ "dbname=postgres",
+ {},
+ ({"dbname": "postgres"}, None),
+ ),
+ (
+ "dbname=postgres connect_timeout=2",
+ {},
+ ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+ ),
+ (
+ "postgresql:///postgres?connect_timeout=2",
+ {"connect_timeout": 10},
+ ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+ ),
+]
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+def test_get_connection_params(conn_cls, dsn, kwargs, exp):
+ params = conn_cls._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params.get("connect_timeout") == exp[1]
+
+
+def test_connect_context(conn_cls, dsn):
+ ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+ ctx.register_dumper(str, make_bin_dumper("b"))
+ ctx.register_dumper(str, make_dumper("t"))
+
+ conn = conn_cls.connect(dsn, context=ctx)
+
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellot"
+ cur = conn.execute("select %b", ["hello"])
+ assert cur.fetchone()[0] == "hellob"
+ conn.close()
+
+
+def test_connect_context_copy(conn_cls, dsn, conn):
+ conn.adapters.register_dumper(str, make_bin_dumper("b"))
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ conn2 = conn_cls.connect(dsn, context=conn)
+
+ cur = conn2.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellot"
+ cur = conn2.execute("select %b", ["hello"])
+ assert cur.fetchone()[0] == "hellob"
+ conn2.close()
+
+
+def test_cancel_closed(conn):
+ conn.close()
+ conn.cancel()
diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py
new file mode 100644
index 0000000..1288a6c
--- /dev/null
+++ b/tests/test_connection_async.py
@@ -0,0 +1,751 @@
+import time
+import pytest
+import logging
+import weakref
+from typing import List, Any
+
+import psycopg
+from psycopg import Notify, errors as e
+from psycopg.rows import tuple_row
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
+
+from .utils import gc_collect
+from .test_cursor import my_row_factory
+from .test_connection import tx_params, tx_params_isolation, tx_values_map
+from .test_connection import conninfo_params_timeout
+from .test_connection import testctx # noqa: F401 # fixture
+from .test_adapt import make_bin_dumper, make_dumper
+from .test_conninfo import fake_resolve # noqa: F401
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_connect(aconn_cls, dsn):
+ conn = await aconn_cls.connect(dsn)
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ await conn.close()
+
+
+async def test_connect_bad(aconn_cls):
+ with pytest.raises(psycopg.OperationalError):
+ await aconn_cls.connect("dbname=nosuchdb")
+
+
+async def test_connect_str_subclass(aconn_cls, dsn):
+ class MyString(str):
+ pass
+
+ conn = await aconn_cls.connect(MyString(dsn))
+ assert not conn.closed
+ assert conn.pgconn.status == conn.ConnStatus.OK
+ await conn.close()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_connect_timeout(aconn_cls, deaf_port):
+ t0 = time.time()
+ with pytest.raises(psycopg.OperationalError, match="timeout expired"):
+ await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+ elapsed = time.time() - t0
+ assert elapsed == pytest.approx(1.0, abs=0.05)
+
+
+async def test_close(aconn):
+ assert not aconn.closed
+ assert not aconn.broken
+
+ cur = aconn.cursor()
+
+ await aconn.close()
+ assert aconn.closed
+ assert not aconn.broken
+ assert aconn.pgconn.status == aconn.ConnStatus.BAD
+
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.pgconn.status == aconn.ConnStatus.BAD
+
+ with pytest.raises(psycopg.OperationalError):
+ await cur.execute("select 1")
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_broken(aconn):
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute(
+ "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid]
+ )
+ assert aconn.closed
+ assert aconn.broken
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.broken
+
+
+async def test_cursor_closed(aconn):
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ async with aconn.cursor("foo"):
+ pass
+ aconn.cursor("foo")
+ with pytest.raises(psycopg.OperationalError):
+ aconn.cursor()
+
+
+async def test_connection_warn_close(aconn_cls, dsn, recwarn):
+ conn = await aconn_cls.connect(dsn)
+ await conn.close()
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+ conn = await aconn_cls.connect(dsn)
+ del conn
+ assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = await aconn_cls.connect(dsn)
+ await conn.execute("select 1")
+ del conn
+ assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
+
+ conn = await aconn_cls.connect(dsn)
+ try:
+ await conn.execute("select wat")
+ except Exception:
+ pass
+ del conn
+ assert "INERROR" in str(recwarn.pop(ResourceWarning).message)
+
+ async with await aconn_cls.connect(dsn) as conn:
+ pass
+ del conn
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.mark.usefixtures("testctx")
+async def test_context_commit(aconn_cls, aconn, dsn):
+ async with aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("insert into testctx values (42)")
+
+ assert aconn.closed
+ assert not aconn.broken
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("select * from testctx")
+ assert await cur.fetchall() == [(42,)]
+
+
+@pytest.mark.usefixtures("testctx")
+async def test_context_rollback(aconn_cls, aconn, dsn):
+ with pytest.raises(ZeroDivisionError):
+ async with aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("insert into testctx values (42)")
+ 1 / 0
+
+ assert aconn.closed
+ assert not aconn.broken
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ async with aconn.cursor() as cur:
+ await cur.execute("select * from testctx")
+ assert await cur.fetchall() == []
+
+
+async def test_context_close(aconn):
+ async with aconn:
+ await aconn.execute("select 1")
+ await aconn.close()
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog):
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn2:
+ await conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn:
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ assert not conn.pgconn.error_message
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.slow
+async def test_weakref(aconn_cls, dsn):
+ conn = await aconn_cls.connect(dsn)
+ w = weakref.ref(conn)
+ await conn.close()
+ del conn
+ gc_collect()
+ assert w() is None
+
+
+async def test_commit(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ aconn.pgconn.exec_(b"insert into foo values (1)")
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.commit()
+
+
+@pytest.mark.crdb_skip("deferrable")
+async def test_commit_error(aconn):
+ await aconn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ await aconn.commit()
+
+ await aconn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ cur = await aconn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+
+async def test_rollback(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ aconn.pgconn.exec_(b"insert into foo values (1)")
+ await aconn.rollback()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.ntuples == 0
+
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.rollback()
+
+
+async def test_auto_transaction(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await cur.execute("insert into foo values (1)")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ await cur.execute("select * from foo")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_auto_transaction_fail(aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await cur.execute("insert into foo values (1)")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("meh")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+ await aconn.commit()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ await cur.execute("select * from foo")
+ assert await cur.fetchone() is None
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_autocommit(aconn):
+ assert aconn.autocommit is False
+ with pytest.raises(AttributeError):
+ aconn.autocommit = True
+ assert not aconn.autocommit
+
+ await aconn.set_autocommit(True)
+ assert aconn.autocommit
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ await aconn.set_autocommit("")
+ assert aconn.autocommit is False
+ await aconn.set_autocommit("yeah")
+ assert aconn.autocommit is True
+
+
+async def test_autocommit_connect(aconn_cls, dsn):
+ aconn = await aconn_cls.connect(dsn, autocommit=True)
+ assert aconn.autocommit
+ await aconn.close()
+
+
+async def test_autocommit_intrans(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+ with pytest.raises(psycopg.ProgrammingError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+async def test_autocommit_inerror(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("meh")
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+ with pytest.raises(psycopg.ProgrammingError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+async def test_autocommit_unknown(aconn):
+ await aconn.close()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.set_autocommit(True)
+ assert not aconn.autocommit
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
+ (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
+ (
+ ("dbname=foo port=5432",),
+ {"dbname": "qux", "user": "joe"},
+ "dbname=qux user=joe port=5432",
+ ),
+ (("dbname=foo",), {"user": None}, "dbname=foo"),
+ ],
+)
+async def test_connect_args(
+ aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
+):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ setpgenv({})
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = await aconn_cls.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ await conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ with pytest.raises(exctype):
+ await aconn_cls.connect(*args, **kwargs)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_broken_connection(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("select pg_terminate_backend(pg_backend_pid())")
+ assert aconn.closed
+
+
+@pytest.mark.crdb_skip("do")
+async def test_notice_handlers(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ messages = []
+ severities = []
+
+ def cb1(diag):
+ messages.append(diag.message_primary)
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ aconn.add_notice_handler(cb1)
+ aconn.add_notice_handler(cb2)
+ aconn.add_notice_handler("the wrong thing")
+ aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
+
+ aconn.pgconn.exec_(b"set client_min_messages to notice")
+ cur = aconn.cursor()
+ await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ aconn.remove_notice_handler(cb1)
+ aconn.remove_notice_handler("the wrong thing")
+ await cur.execute(
+ "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
+ )
+ assert len(caplog.records) == 3
+ assert messages == ["hello notice"]
+ assert severities == ["NOTICE", "WARNING"]
+
+ with pytest.raises(ValueError):
+ aconn.remove_notice_handler(cb1)
+
+
+@pytest.mark.crdb_skip("notify")
+async def test_notify_handlers(aconn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ aconn.add_notify_handler(cb1)
+ aconn.add_notify_handler(lambda n: nots2.append(n))
+
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ await cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ aconn.remove_notify_handler(cb1)
+ await cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ with pytest.raises(ValueError):
+ aconn.remove_notify_handler(cb1)
+
+
+async def test_execute(aconn):
+ cur = await aconn.execute("select %s, %s", [10, 20])
+ assert await cur.fetchone() == (10, 20)
+ assert cur.format == 0
+ assert cur.pgresult.fformat(0) == 0
+
+ cur = await aconn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21})
+ assert await cur.fetchone() == (11, 21)
+
+ cur = await aconn.execute("select 12, 22")
+ assert await cur.fetchone() == (12, 22)
+
+
+async def test_execute_binary(aconn):
+ cur = await aconn.execute("select %s, %s", [10, 20], binary=True)
+ assert await cur.fetchone() == (10, 20)
+ assert cur.format == 1
+ assert cur.pgresult.fformat(0) == 1
+
+
+async def test_row_factory(aconn_cls, dsn):
+ defaultconn = await aconn_cls.connect(dsn)
+ assert defaultconn.row_factory is tuple_row
+ await defaultconn.close()
+
+ conn = await aconn_cls.connect(dsn, row_factory=my_row_factory)
+ assert conn.row_factory is my_row_factory
+
+ cur = await conn.execute("select 'a' as ve")
+ assert await cur.fetchone() == ["Ave"]
+
+ async with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1:
+ await cur1.execute("select 1, 1, 2")
+ assert await cur1.fetchall() == [{1, 2}]
+
+ async with conn.cursor(row_factory=tuple_row) as cur2:
+ await cur2.execute("select 1, 1, 2")
+ assert await cur2.fetchall() == [(1, 1, 2)]
+
+ # TODO: maybe fix something to get rid of 'type: ignore' below.
+ conn.row_factory = tuple_row
+ cur3 = await conn.execute("select 'vale'")
+ r = await cur3.fetchone()
+ assert r and r == ("vale",)
+ await conn.close()
+
+
+async def test_str(aconn):
+ assert "[IDLE]" in str(aconn)
+ await aconn.close()
+ assert "[BAD]" in str(aconn)
+
+
+async def test_fileno(aconn):
+ assert aconn.fileno() == aconn.pgconn.socket
+ await aconn.close()
+ with pytest.raises(psycopg.OperationalError):
+ aconn.fileno()
+
+
+async def test_cursor_factory(aconn):
+ assert aconn.cursor_factory is psycopg.AsyncCursor
+
+ class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]):
+ pass
+
+ aconn.cursor_factory = MyCursor
+ async with aconn.cursor() as cur:
+ assert isinstance(cur, MyCursor)
+
+ async with (await aconn.execute("select 1")) as cur:
+ assert isinstance(cur, MyCursor)
+
+
+async def test_cursor_factory_connect(aconn_cls, dsn):
+ class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]):
+ pass
+
+ async with await aconn_cls.connect(dsn, cursor_factory=MyCursor) as conn:
+ assert conn.cursor_factory is MyCursor
+ cur = conn.cursor()
+ assert type(cur) is MyCursor
+
+
+async def test_server_cursor_factory(aconn):
+ assert aconn.server_cursor_factory is psycopg.AsyncServerCursor
+
+ class MyServerCursor(psycopg.AsyncServerCursor[psycopg.rows.Row]):
+ pass
+
+ aconn.server_cursor_factory = MyServerCursor
+ async with aconn.cursor(name="n") as cur:
+ assert isinstance(cur, MyServerCursor)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_transaction_param_default(aconn, param):
+ assert getattr(aconn, param.name) is None
+ cur = await aconn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ )
+ current, default = await cur.fetchone()
+ assert current == default
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_transaction_param_readonly_property(aconn, param):
+ with pytest.raises(AttributeError):
+ setattr(aconn, param.name, None)
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+async def test_set_transaction_param_implicit(aconn, param, autocommit):
+ await aconn.set_autocommit(autocommit)
+ for value in param.values:
+ await getattr(aconn, f"set_{param.name}")(value)
+ cur = await aconn.execute(
+ "select current_setting(%s), current_setting(%s)",
+ [f"transaction_{param.guc}", f"default_transaction_{param.guc}"],
+ )
+ pgval, default = await cur.fetchone()
+ if autocommit:
+ assert pgval == default
+ else:
+ assert tx_values_map[pgval] == value
+ await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [True, False])
+@pytest.mark.parametrize("param", tx_params_isolation)
+async def test_set_transaction_param_block(aconn, param, autocommit):
+ await aconn.set_autocommit(autocommit)
+ for value in param.values:
+ await getattr(aconn, f"set_{param.name}")(value)
+ async with aconn.transaction():
+ cur = await aconn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ )
+ pgval = (await cur.fetchone())[0]
+ assert tx_values_map[pgval] == value
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_implicit(aconn, param):
+ await aconn.execute("select 1")
+ value = param.values[0]
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_block(aconn, param):
+ value = param.values[0]
+ async with aconn.transaction():
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.parametrize("param", tx_params)
+async def test_set_transaction_param_not_intrans_external(aconn, param):
+ value = param.values[0]
+ await aconn.set_autocommit(True)
+ await aconn.execute("begin")
+ with pytest.raises(psycopg.ProgrammingError):
+ await getattr(aconn, f"set_{param.name}")(value)
+
+
+@pytest.mark.crdb("skip", reason="transaction isolation")
+async def test_set_transaction_param_all(aconn):
+ params: List[Any] = tx_params[:]
+ params[2] = params[2].values[0]
+
+ for param in params:
+ value = param.values[0]
+ await getattr(aconn, f"set_{param.name}")(value)
+
+ for param in params:
+ cur = await aconn.execute(
+ "select current_setting(%s)", [f"transaction_{param.guc}"]
+ )
+ pgval = (await cur.fetchone())[0]
+ assert tx_values_map[pgval] == value
+
+
+async def test_set_transaction_param_strange(aconn):
+ for val in ("asdf", 0, 5):
+ with pytest.raises(ValueError):
+ await aconn.set_isolation_level(val)
+
+ await aconn.set_isolation_level(psycopg.IsolationLevel.SERIALIZABLE.value)
+ assert aconn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE
+
+ await aconn.set_read_only(1)
+ assert aconn.read_only is True
+
+ await aconn.set_deferrable(0)
+ assert aconn.deferrable is False
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
+ setpgenv({})
+ params = await aconn_cls._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params["connect_timeout"] == exp[1]
+
+
+async def test_connect_context_adapters(aconn_cls, dsn):
+ ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+ ctx.register_dumper(str, make_bin_dumper("b"))
+ ctx.register_dumper(str, make_dumper("t"))
+
+ conn = await aconn_cls.connect(dsn, context=ctx)
+
+ cur = await conn.execute("select %s", ["hello"])
+ assert (await cur.fetchone())[0] == "hellot"
+ cur = await conn.execute("select %b", ["hello"])
+ assert (await cur.fetchone())[0] == "hellob"
+ await conn.close()
+
+
+async def test_connect_context_copy(aconn_cls, dsn, aconn):
+ aconn.adapters.register_dumper(str, make_bin_dumper("b"))
+ aconn.adapters.register_dumper(str, make_dumper("t"))
+
+ aconn2 = await aconn_cls.connect(dsn, context=aconn)
+
+ cur = await aconn2.execute("select %s", ["hello"])
+ assert (await cur.fetchone())[0] == "hellot"
+ cur = await aconn2.execute("select %b", ["hello"])
+ assert (await cur.fetchone())[0] == "hellob"
+ await aconn2.close()
+
+
+async def test_cancel_closed(aconn):
+ await aconn.close()
+ aconn.cancel()
+
+
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
+ 1 / 0
+
+ monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+
+ with pytest.raises(ZeroDivisionError):
+ await psycopg.AsyncConnection.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert conninfo_to_dict(got[0]) == want
diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py
new file mode 100644
index 0000000..e2c2c01
--- /dev/null
+++ b/tests/test_conninfo.py
@@ -0,0 +1,450 @@
+import socket
+import asyncio
+import datetime as dt
+
+import pytest
+
+import psycopg
+from psycopg import ProgrammingError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg.conninfo import resolve_hostaddr_async
+from psycopg._encodings import pg2pyenc
+
+from .fix_crdb import crdb_encoding
+
+snowman = "\u2603"
+
+
+class MyString(str):
+ pass
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs, exp",
+ [
+ ("", {}, ""),
+ ("dbname=foo", {}, "dbname=foo"),
+ ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"),
+ ("dbname=sony", {"password": ""}, "dbname=sony password="),
+ ("dbname=foo", {"dbname": "bar"}, "dbname=bar"),
+ ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"),
+ ("", {"dbname": "foo"}, "dbname=foo"),
+ ("", {"dbname": "foo", "user": None}, "dbname=foo"),
+ ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"),
+ ("", {"dbname": "a'b"}, r"dbname='a\'b'"),
+ (f"dbname={snowman}", {}, f"dbname={snowman}"),
+ ("", {"dbname": snowman}, f"dbname={snowman}"),
+ (
+ "postgresql://host1/test",
+ {"host": "host2"},
+ "dbname=test host=host2",
+ ),
+ (MyString(""), {}, ""),
+ ],
+)
+def test_make_conninfo(conninfo, kwargs, exp):
+ out = make_conninfo(conninfo, **kwargs)
+ assert conninfo_to_dict(out) == conninfo_to_dict(exp)
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs",
+ [
+ ("hello", {}),
+ ("dbname=foo bar", {}),
+ ("foo=bar", {}),
+ ("dbname=foo", {"bar": "baz"}),
+ ("postgresql://tester:secret@/test?port=5433=x", {}),
+ (f"{snowman}={snowman}", {}),
+ ],
+)
+def test_make_conninfo_bad(conninfo, kwargs):
+ with pytest.raises(ProgrammingError):
+ make_conninfo(conninfo, **kwargs)
+
+
+@pytest.mark.parametrize(
+ "conninfo, exp",
+ [
+ ("", {}),
+ ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}),
+ ("dbname=sony password=", {"dbname": "sony", "password": ""}),
+ ("dbname='foo bar'", {"dbname": "foo bar"}),
+ ("dbname='a\"b'", {"dbname": 'a"b'}),
+ (r"dbname='a\'b'", {"dbname": "a'b"}),
+ (r"dbname='a\\b'", {"dbname": r"a\b"}),
+ (f"dbname={snowman}", {"dbname": snowman}),
+ (
+ "postgresql://tester:secret@/test?port=5433",
+ {
+ "user": "tester",
+ "password": "secret",
+ "dbname": "test",
+ "port": "5433",
+ },
+ ),
+ ],
+)
+def test_conninfo_to_dict(conninfo, exp):
+ assert conninfo_to_dict(conninfo) == exp
+
+
+def test_no_munging():
+ dsnin = "dbname=a host=b user=c password=d"
+ dsnout = make_conninfo(dsnin)
+ assert dsnin == dsnout
+
+
+class TestConnectionInfo:
+ @pytest.mark.parametrize(
+ "attr",
+ [("dbname", "db"), "host", "hostaddr", "user", "password", "options"],
+ )
+ def test_attrs(self, conn, attr):
+ if isinstance(attr, tuple):
+ info_attr, pgconn_attr = attr
+ else:
+ info_attr = pgconn_attr = attr
+
+ if info_attr == "hostaddr" and psycopg.pq.version() < 120000:
+ pytest.skip("hostaddr not supported on libpq < 12")
+
+ info_val = getattr(conn.info, info_attr)
+ pgconn_val = getattr(conn.pgconn, pgconn_attr).decode()
+ assert info_val == pgconn_val
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ getattr(conn.info, info_attr)
+
+ @pytest.mark.libpq("< 12")
+ def test_hostaddr_not_supported(self, conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.info.hostaddr
+
+ def test_port(self, conn):
+ assert conn.info.port == int(conn.pgconn.port.decode())
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.port
+
+ def test_get_params(self, conn, dsn):
+ info = conn.info.get_parameters()
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert info.get(k) == v
+ else:
+ assert k not in info
+
+ def test_dsn(self, conn, dsn):
+ dsn = conn.info.dsn
+ assert "password" not in dsn
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert f"{k}=" in dsn
+
+ def test_get_params_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name" not in conn.info.get_parameters()
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert conn.info.get_parameters()["application_name"] == "hello test"
+
+ def test_dsn_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name=" not in conn.info.dsn
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name='hello test'" in conn.info.dsn
+
+ def test_status(self, conn):
+ assert conn.info.status.name == "OK"
+ conn.close()
+ assert conn.info.status.name == "BAD"
+
+ def test_transaction_status(self, conn):
+ assert conn.info.transaction_status.name == "IDLE"
+ conn.close()
+ assert conn.info.transaction_status.name == "UNKNOWN"
+
+ @pytest.mark.pipeline
+ def test_pipeline_status(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+ with conn.pipeline():
+ assert conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "ON"
+
+ @pytest.mark.libpq("< 14")
+ def test_pipeline_status_no_pipeline(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+
+ def test_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.get_parameters()
+ assert info.get_parameters()["dbname"] == info.dbname
+
+ def test_dsn_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.dsn
+ assert f"dbname={info.dbname}" in info.dsn
+
+ def test_parameter_status(self, conn):
+ assert conn.info.parameter_status("nosuchparam") is None
+ tz = conn.info.parameter_status("TimeZone")
+ assert tz and isinstance(tz, str)
+ assert tz == conn.execute("show timezone").fetchone()[0]
+
+ @pytest.mark.crdb("skip")
+ def test_server_version(self, conn):
+ assert conn.info.server_version == conn.pgconn.server_version
+
+ def test_error_message(self, conn):
+ assert conn.info.error_message == ""
+ with pytest.raises(psycopg.ProgrammingError) as ex:
+ conn.execute("wat")
+
+ assert conn.info.error_message
+ assert str(ex.value) in conn.info.error_message
+ assert ex.value.diag.severity in conn.info.error_message
+
+ conn.close()
+ assert "NULL" in conn.info.error_message
+
+ @pytest.mark.crdb_skip("backend pid")
+ def test_backend_pid(self, conn):
+ assert conn.info.backend_pid
+ assert conn.info.backend_pid == conn.pgconn.backend_pid
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.backend_pid
+
+ def test_timezone(self, conn):
+ conn.execute("set timezone to 'Europe/Rome'")
+ tz = conn.info.timezone
+ assert isinstance(tz, dt.tzinfo)
+ offset = tz.utcoffset(dt.datetime(2000, 1, 1))
+ assert offset and offset.total_seconds() == 3600
+ offset = tz.utcoffset(dt.datetime(2000, 7, 1))
+ assert offset and offset.total_seconds() == 7200
+
+ @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones")
+ def test_timezone_warn(self, conn, caplog):
+ conn.execute("set timezone to 'FOOBAR0'")
+ assert len(caplog.records) == 0
+ tz = conn.info.timezone
+ assert tz == dt.timezone.utc
+ assert len(caplog.records) == 1
+ assert "FOOBAR0" in caplog.records[0].message
+
+ conn.info.timezone
+ assert len(caplog.records) == 1
+
+ conn.execute("set timezone to 'FOOBAAR0'")
+ assert len(caplog.records) == 1
+ conn.info.timezone
+ assert len(caplog.records) == 2
+ assert "FOOBAAR0" in caplog.records[1].message
+
+ def test_encoding(self, conn):
+ enc = conn.execute("show client_encoding").fetchone()[0]
+ assert conn.info.encoding == pg2pyenc(enc.encode())
+
+ @pytest.mark.crdb("skip", reason="encoding not normalized")
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ ("eucjp", "EUC_JP", "euc_jp"),
+ ("euc-jp", "EUC_JP", "euc_jp"),
+ ("latin9", "LATIN9", "iso8859-15"),
+ ],
+ )
+ def test_normalize_encoding(self, conn, enc, out, codec):
+ conn.execute("select set_config('client_encoding', %s, false)", [enc])
+ assert conn.info.parameter_status("client_encoding") == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ crdb_encoding("eucjp", "EUC_JP", "euc_jp"),
+ crdb_encoding("euc-jp", "EUC_JP", "euc_jp"),
+ ],
+ )
+ def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec):
+ monkeypatch.setenv("PGCLIENTENCODING", enc)
+ with conn_cls.connect(dsn) as conn:
+ clienc = conn.info.parameter_status("client_encoding")
+ assert clienc
+ if conn.info.vendor == "PostgreSQL":
+ assert clienc == out
+ else:
+ assert clienc.replace("-", "").replace("_", "").upper() == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.crdb_skip("encoding")
+ def test_set_encoding_unsupported(self, conn):
+ cur = conn.cursor()
+ cur.execute("set client_encoding to EUC_TW")
+ with pytest.raises(psycopg.NotSupportedError):
+ cur.execute("select 'x'")
+
+ def test_vendor(self, conn):
+ assert conn.info.vendor
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", "", None),
+ ("host='' user=bar", "host='' user=bar", None),
+ (
+ "host=127.0.0.1 user=bar",
+ "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ "host=foo.com port=5432",
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(
+ setpgenv, conninfo, want, env, fail_resolve
+):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ {},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ await resolve_hostaddr_async(params)
+
+
+@pytest.fixture
+async def fake_resolve(monkeypatch):
+ fake_hosts = {
+ "localhost": "127.0.0.1",
+ "foo.com": "1.1.1.1",
+ "qux.com": "2.2.2.2",
+ }
+
+ async def fake_getaddrinfo(host, port, **kwargs):
+ assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
+ try:
+ addr = fake_hosts[host]
+ except KeyError:
+ raise OSError(f"unknown test host: {host}")
+ else:
+ return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture
+async def fail_resolve(monkeypatch):
+ async def fail_getaddrinfo(host, port, **kwargs):
+ pytest.fail(f"shouldn't try to resolve {host}")
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
diff --git a/tests/test_copy.py b/tests/test_copy.py
new file mode 100644
index 0000000..17cf2fc
--- /dev/null
+++ b/tests/test_copy.py
@@ -0,0 +1,889 @@
+import string
+import struct
+import hashlib
+from io import BytesIO, StringIO
+from random import choice, randrange
+from itertools import cycle
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter
+from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
+
+from .utils import eur, gc_collect, gc_count
+
+pytestmark = pytest.mark.crdb_skip("copy")
+
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
+sample_tabledef = "col1 serial primary key, col2 int, data text"
+
+sample_text = b"""\
+40010\t40020\thello
+40040\t\\N\tworld
+"""
+
+sample_binary_str = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
+]
+sample_binary = b"".join(sample_binary_rows)
+
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_read(conn, format):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ for row in want:
+ got = copy.read()
+ assert got == row
+ assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE
+
+ assert copy.read() == b""
+ assert copy.read() == b""
+
+ assert copy.read() == b""
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_iter(conn, format, row_factory):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ assert list(copy) == want
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_no_result(conn, format, row_factory):
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+ with pytest.raises(e.ProgrammingError):
+ cur.fetchone()
+
+
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+def test_read_rows(conn, format, typetype):
+ cur = conn.cursor()
+ with cur.copy(
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
+ ) as copy:
+ copy.set_types(["int4", "text", "float8[]"])
+ row = copy.read_row()
+ assert copy.read_row() is None
+
+ assert row == (10, "hello", [0.0, 1.0])
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ rows = list(copy.rows())
+
+ assert rows == sample_records
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_set_custom_type(conn, hstore):
+ command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+ cur = conn.cursor()
+
+ with cur.copy(command) as copy:
+ rows = list(copy.rows())
+
+ assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+ register_hstore(TypeInfo.fetch(conn, "hstore"), cur)
+ with cur.copy(command) as copy:
+ copy.set_types(["hstore"])
+ rows = list(copy.rows())
+
+ assert rows == [({"a": "1", "b": "2"},)]
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_allchars(conn, format):
+ cur = conn.cursor()
+ chars = list(map(chr, range(1, 256))) + [eur]
+ conn.execute("set client_encoding to utf8")
+ rows = []
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
+ with cur.copy(query) as copy:
+ copy.set_types(["text"])
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ assert len(row) == 1
+ rows.append(row[0])
+
+ assert rows == chars
+
+
+@pytest.mark.parametrize("format", Format)
+def test_read_row_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = []
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ rows.append(row)
+
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = list(copy.rows())
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_badntypes(conn, format, err):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types([0] * (len(sample_records[0]) + err))
+ with pytest.raises(e.ProgrammingError):
+ copy.read_row()
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_bad_result(conn):
+ conn.autocommit = True
+
+ cur = conn.cursor()
+
+ with pytest.raises(e.SyntaxError):
+ with cur.copy("wat"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("reset timezone"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("copy (select 1) to stdout; select 1") as copy:
+ list(copy)
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1; copy (select 1) to stdout"):
+ pass
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin (format binary)") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+def test_copy_big_size_block(conn, pytype):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.parametrize("format", Format)
+def test_subclass_adapter(conn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import ( # type: ignore[no-redef]
+ StrBinaryDumper as BaseDumper,
+ )
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
+ copy.write_row(("hello",))
+
+ rec = cur.execute("select data from copy_in").fetchone()
+ assert rec[0] == "hellohello"
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_error_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_error_with_copy_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_copy_out_error_with_copy_not_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_server_error(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ for block in copy:
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ conn.execute("set client_encoding to utf8")
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+def test_copy_in_format(conn):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, writer=FileWriter(file)) as copy:
+ for i in range(1, 256):
+ copy.write_row((i, chr(i)))
+
+ file.seek(0)
+ rows = file.read().split(b"\n")
+ assert not rows[-1]
+ del rows[-1]
+
+ for i, row in enumerate(rows, start=1):
+ fields = row.split(b"\t")
+ assert len(fields) == 2
+ assert int(fields[0].decode()) == i
+ if i in special_chars:
+ assert fields[1].decode() == f"\\{special_chars[i]}"
+ else:
+ assert fields[1].decode() == chr(i)
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_file_writer(conn, format, buffer):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, binary=format, writer=FileWriter(file)) as copy:
+ for record in sample_records:
+ copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
+@pytest.mark.slow
+def test_copy_from_to(conn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_copy_from_to_bytes(conn, pytype):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(pytype(block.encode()))
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_insane_size(conn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+
+def test_copy_rowcount(conn):
+ gen = DataGenerator(conn, nrecs=3, srec=10)
+ gen.ensure_table()
+
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(conn, nrecs=2, srec=10, offset=3)
+ with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
+def test_copy_query(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert cur._query.query == b"copy (select 1) to stdout"
+ assert not cur._query.params
+ list(copy)
+
+
+def test_cant_reenter(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ list(copy)
+
+ with pytest.raises(TypeError):
+ with copy:
+ list(copy)
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert "[ACTIVE]" in str(copy)
+ list(copy)
+
+ assert "[INTRANS]" in str(copy)
+
+
+def test_description(conn):
+ with conn.cursor() as cur:
+ with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy:
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+ list(copy.rows())
+
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_worker_life(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=QueuedLibpqDriver(cur)
+ ) as copy:
+ assert not copy.writer._worker
+ copy.write(globals()[buffer])
+ assert copy.writer._worker
+
+ assert not copy.writer._worker
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_worker_error_propagated(conn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = conn.cursor()
+ cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy wat from stdin", writer=QueuedLibpqDriver(cur)) as copy:
+ copy.write("a,b")
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_connection_writer(conn, format, buffer):
+ cur = conn.cursor()
+ writer = LibpqWriter(cur)
+
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
+def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ stmt = sql.SQL(
+ "copy (select {} from {} order by id) to stdout (format {})"
+ ).format(
+ sql.SQL(", ").join(faker.fields_names),
+ faker.table_name,
+ sql.SQL(fmt.name),
+ )
+
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+
+ if method == "read":
+ while True:
+ tmp = copy.read()
+ if not tmp:
+ break
+ elif method == "iter":
+ list(copy)
+ elif method == "row":
+ while True:
+ tmp = copy.read_row()
+ if tmp is None:
+ break
+ elif method == "rows":
+ list(copy.rows())
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin (format {})").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL(fmt.name),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("mode", ["row", "block", "binary"])
+def test_copy_table_across(conn_cls, dsn, faker, mode):
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2:
+ faker.table_name = sql.Identifier("copy_src")
+ conn1.execute(faker.drop_stmt)
+ conn1.execute(faker.create_stmt)
+ conn1.cursor().executemany(faker.insert_stmt, faker.records)
+
+ faker.table_name = sql.Identifier("copy_tgt")
+ conn2.execute(faker.drop_stmt)
+ conn2.execute(faker.create_stmt)
+
+ fmt = "(format binary)" if mode == "binary" else ""
+ with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1:
+ with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2:
+ if mode == "row":
+ for row in copy1.rows():
+ copy2.write_row(row)
+ else:
+ for data in copy1:
+ copy2.write(data)
+
+ recs = conn2.execute(faker.select_stmt).fetchall()
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+def py_to_raw(item, fmt):
+ """Convert from Python type to the expected result from the db"""
+ if fmt == Format.TEXT:
+ if isinstance(item, int):
+ return str(item)
+ else:
+ if isinstance(item, int):
+ # Assume int4
+ return struct.pack("!i", item)
+ elif isinstance(item, str):
+ return item.encode()
+ return item
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+ cur.execute(f"drop table if exists {name}")
+ cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ def ensure_table(self):
+ cur = self.conn.cursor()
+ ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ def assert_data(self):
+ cur = self.conn.cursor()
+ cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == cur.fetchone()
+
+ assert cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while True:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode()
+ m.update(block)
+ return m.hexdigest()
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
new file mode 100644
index 0000000..59389dd
--- /dev/null
+++ b/tests/test_copy_async.py
@@ -0,0 +1,892 @@
+import string
+import hashlib
+from io import BytesIO, StringIO
+from random import choice, randrange
+from itertools import cycle
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.copy import AsyncCopy
+from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter
+from psycopg.types import TypeInfo
+from psycopg.adapt import PyFormat
+from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
+
+from .utils import alist, eur, gc_collect, gc_count
+from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
+from .test_copy import sample_values, sample_records, sample_tabledef
+from .test_copy import py_to_raw, special_chars
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("copy"),
+]
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_read(aconn, format):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ for row in want:
+ got = await copy.read()
+ assert got == row
+ assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
+
+ assert await copy.read() == b""
+ assert await copy.read() == b""
+
+ assert await copy.read() == b""
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_iter(aconn, format, row_factory):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ rf = getattr(psycopg.rows, row_factory)
+ cur = aconn.cursor(row_factory=rf)
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ assert await alist(copy) == want
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_no_result(aconn, format, row_factory):
+ rf = getattr(psycopg.rows, row_factory)
+ cur = aconn.cursor(row_factory=rf)
+ async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+ with pytest.raises(e.ProgrammingError):
+ await cur.fetchone()
+
+
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+async def test_copy_out_param(aconn, ph, params):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert await alist(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+async def test_read_rows(aconn, format, typetype):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
+ ) as copy:
+ copy.set_types(["int4", "text", "float8[]"])
+ row = await copy.read_row()
+ assert (await copy.read_row()) is None
+
+ assert row == (10, "hello", [0.0, 1.0])
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_rows(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ copy.set_types("int4 int4 text".split())
+ rows = await alist(copy.rows())
+
+ assert rows == sample_records
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_set_custom_type(aconn, hstore):
+ command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+ cur = aconn.cursor()
+
+ async with cur.copy(command) as copy:
+ rows = await alist(copy.rows())
+
+ assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+ register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur)
+ async with cur.copy(command) as copy:
+ copy.set_types(["hstore"])
+ rows = await alist(copy.rows())
+
+ assert rows == [({"a": "1", "b": "2"},)]
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_allchars(aconn, format):
+ cur = aconn.cursor()
+ chars = list(map(chr, range(1, 256))) + [eur]
+ await aconn.execute("set client_encoding to utf8")
+ rows = []
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
+ async with cur.copy(query) as copy:
+ copy.set_types(["text"])
+ while True:
+ row = await copy.read_row()
+ if not row:
+ break
+ assert len(row) == 1
+ rows.append(row[0])
+
+ assert rows == chars
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_read_row_notypes(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ rows = []
+ while True:
+ row = await copy.read_row()
+ if not row:
+ break
+ rows.append(row)
+
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_rows_notypes(aconn, format):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ rows = await alist(copy.rows())
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", Format)
+async def test_copy_out_badntypes(aconn, format, err):
+ cur = aconn.cursor()
+ async with cur.copy(
+ f"copy ({sample_values}) to stdout (format {format.name})"
+ ) as copy:
+ copy.set_types([0] * (len(sample_records[0]) + err))
+ with pytest.raises(e.ProgrammingError):
+ await copy.read_row()
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_bad_result(aconn):
+ await aconn.set_autocommit(True)
+
+ cur = aconn.cursor()
+
+ with pytest.raises(e.SyntaxError):
+ async with cur.copy("wat"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("select 1"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("reset timezone"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("copy (select 1) to stdout; select 1") as copy:
+ await alist(copy)
+
+ with pytest.raises(e.ProgrammingError):
+ async with cur.copy("select 1; copy (select 1) to stdout"):
+ pass
+
+
+async def test_copy_in_str(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text.decode())
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ async with cur.copy("copy copy_in from stdin (format binary)") as copy:
+ await copy.write(sample_text.decode())
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+async def test_copy_big_size_record(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write_row([data])
+
+ await cur.execute("select data from copy_in limit 1")
+ assert await cur.fetchone() == (data,)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+async def test_copy_big_size_block(aconn, pytype):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write(copy_data)
+
+ await cur.execute("select data from copy_in limit 1")
+ assert await cur.fetchone() == (data,)
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_subclass_adapter(aconn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import ( # type: ignore[no-redef]
+ StrBinaryDumper as BaseDumper,
+ )
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ aconn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(
+ f"copy copy_in (data) from stdin (format {format.name})"
+ ) as copy:
+ await copy.write_row(("hello",))
+
+ await cur.execute("select data from copy_in")
+ rec = await cur.fetchone()
+ assert rec[0] == "hellohello"
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_error_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_error_with_copy_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_copy_out_error_with_copy_not_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy (select generate_series(1, 1000000)) to stdout"
+ ) as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_server_error(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ async with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ async for block in copy:
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_binary(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ await aconn.execute("set client_encoding to utf8")
+ async with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+async def test_copy_in_format(aconn):
+ file = BytesIO()
+ await aconn.execute("set client_encoding to utf8")
+ cur = aconn.cursor()
+ async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, chr(i)))
+
+ file.seek(0)
+ rows = file.read().split(b"\n")
+ assert not rows[-1]
+ del rows[-1]
+
+ for i, row in enumerate(rows, start=1):
+ fields = row.split(b"\t")
+ assert len(fields) == 2
+ assert int(fields[0].decode()) == i
+ if i in special_chars:
+ assert fields[1].decode() == f"\\{special_chars[i]}"
+ else:
+ assert fields[1].decode() == chr(i)
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_file_writer(aconn, format, buffer):
+ file = BytesIO()
+ await aconn.execute("set client_encoding to utf8")
+ cur = aconn.cursor()
+ async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy:
+ for record in sample_records:
+ await copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
+@pytest.mark.slow
+async def test_copy_from_to(aconn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+ f = BytesIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+async def test_copy_from_to_bytes(aconn, pytype):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(pytype(block.encode()))
+
+ await gen.assert_data()
+
+ f = BytesIO()
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+async def test_copy_from_insane_size(aconn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ await gen.ensure_table()
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+
+ await gen.assert_data()
+
+
+async def test_copy_rowcount(aconn):
+ gen = DataGenerator(aconn, nrecs=3, srec=10)
+ await gen.ensure_table()
+
+ cur = aconn.cursor()
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ await copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ async with cur.copy("copy copy_in to stdout") as copy:
+ async for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ async with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ await copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
+async def test_copy_query(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ assert cur._query.query == b"copy (select 1) to stdout"
+ assert not cur._query.params
+ await alist(copy)
+
+
+async def test_cant_reenter(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ await alist(copy)
+
+ with pytest.raises(TypeError):
+ async with copy:
+ await alist(copy)
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ assert "[ACTIVE]" in str(copy)
+ await alist(copy)
+
+ assert "[INTRANS]" in str(copy)
+
+
+async def test_description(aconn):
+ async with aconn.cursor() as cur:
+ async with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy:
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+ await alist(copy.rows())
+
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_worker_life(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})",
+ writer=AsyncQueuedLibpqWriter(cur),
+ ) as copy:
+ assert not copy.writer._worker
+ await copy.write(globals()[buffer])
+ assert copy.writer._worker
+
+ assert not copy.writer._worker
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_worker_error_propagated(aconn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = aconn.cursor()
+ await cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy wat from stdin", writer=AsyncQueuedLibpqWriter(cur)
+ ) as copy:
+ await copy.write("a,b")
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_connection_writer(aconn, format, buffer):
+ cur = aconn.cursor()
+ writer = AsyncLibpqWriter(cur)
+
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
+async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+
+ stmt = sql.SQL(
+ "copy (select {} from {} order by id) to stdout (format {})"
+ ).format(
+ sql.SQL(", ").join(faker.fields_names),
+ faker.table_name,
+ sql.SQL(fmt.name),
+ )
+
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+
+ if method == "read":
+ while True:
+ tmp = await copy.read()
+ if not tmp:
+ break
+ elif method == "iter":
+ await alist(copy)
+ elif method == "row":
+ while True:
+ tmp = await copy.read_row()
+ if tmp is None:
+ break
+ elif method == "rows":
+ await alist(copy.rows())
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin (format {})").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL(fmt.name),
+ )
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ await copy.write_row(row)
+
+ await cur.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("mode", ["row", "block", "binary"])
+async def test_copy_table_across(aconn_cls, dsn, faker, mode):
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ connect = aconn_cls.connect
+ async with await connect(dsn) as conn1, await connect(dsn) as conn2:
+ faker.table_name = sql.Identifier("copy_src")
+ await conn1.execute(faker.drop_stmt)
+ await conn1.execute(faker.create_stmt)
+ await conn1.cursor().executemany(faker.insert_stmt, faker.records)
+
+ faker.table_name = sql.Identifier("copy_tgt")
+ await conn2.execute(faker.drop_stmt)
+ await conn2.execute(faker.create_stmt)
+
+ fmt = "(format binary)" if mode == "binary" else ""
+ async with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1:
+ async with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2:
+ if mode == "row":
+ async for row in copy1.rows():
+ await copy2.write_row(row)
+ else:
+ async for data in copy1:
+ await copy2.write(data)
+
+ cur = await conn2.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+async def ensure_table(cur, tabledef, name="copy_in"):
+ await cur.execute(f"drop table if exists {name}")
+ await cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ async def ensure_table(self):
+ cur = self.conn.cursor()
+ await ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ async def assert_data(self):
+ cur = self.conn.cursor()
+ await cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == await cur.fetchone()
+
+ assert await cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while True:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode()
+ m.update(block)
+ return m.hexdigest()
+
+
+class AsyncFileWriter(AsyncWriter):
+ def __init__(self, file):
+ self.file = file
+
+ async def write(self, data):
+ self.file.write(data)
diff --git a/tests/test_cursor.py b/tests/test_cursor.py
new file mode 100644
index 0000000..a667f4f
--- /dev/null
+++ b/tests/test_cursor.py
@@ -0,0 +1,942 @@
+import pickle
+import weakref
+import datetime as dt
+from typing import List, Union
+from contextlib import closing
+
+import pytest
+
+import psycopg
+from psycopg import pq, sql, rows
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+from psycopg.rows import RowMaker
+
+from .utils import gc_collect, gc_count
+from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
+
+
+def test_init(conn):
+ cur = psycopg.Cursor(conn)
+ cur.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ conn.row_factory = rows.dict_row
+ cur = psycopg.Cursor(conn)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_init_factory(conn):
+ cur = psycopg.Cursor(conn, row_factory=rows.dict_row)
+ cur.execute("select 1 as a")
+ assert cur.fetchone() == {"a": 1}
+
+
+def test_close(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.execute("select 'foo'")
+
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ cur.fetchall()
+
+
+def test_context(conn):
+ with conn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+def test_weakref(conn):
+ cur = conn.cursor()
+ w = weakref.ref(cur)
+ cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_statusmessage(conn):
+ cur = conn.cursor()
+ assert cur.statusmessage is None
+
+ cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+def test_execute_many_results(conn):
+ cur = conn.cursor()
+ assert cur.nextset() is None
+
+ rv = cur.execute("select 'foo'; select generate_series(1,3)")
+ assert rv is cur
+ assert cur.fetchall() == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.nextset() is None
+
+ cur.close()
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+ cur = conn.cursor()
+ rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+def test_execute_type_change(conn):
+ # issue #112
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.execute(sql, (1,))
+ cur.execute(sql, (100_000,))
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+def test_executemany_type_change(conn):
+ conn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = conn.cursor()
+ cur.executemany(sql, [(1,), (100_000,)])
+ cur.execute("select num from bug_112 order by num")
+ assert cur.fetchall() == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+def test_execute_copy(conn, query):
+ cur = conn.cursor()
+ cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute(query)
+
+
+def test_fetchone(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = cur.fetchone()
+ assert row == (1, "foo", None)
+ row = cur.fetchone()
+ assert row is None
+
+
+def test_binary_cursor_execute(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None])
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+def test_execute_binary(conn):
+ cur = conn.cursor()
+ cur.execute("select %s, %s", [1, None], binary=True)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s, %s", [1, None], binary=False)
+ assert cur.fetchone() == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+def test_query_encode(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("select '\u20ac'").fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+def test_query_badenc(conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+def test_executemany_no_data(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+def test_executemany_rowcount(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+def test_executemany_returning(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_returning_discard(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+def test_executemany_no_result(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+def test_executemany_rowcount_no_hit(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_executemany_null_first(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table testmany (a bigint, b bigint)")
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+def test_rowcount(conn):
+ cur = conn.cursor()
+
+ cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ cur.execute("show timezone")
+ assert cur.rowcount == 1
+
+ cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+
+def test_rownumber(conn):
+ cur = conn.cursor()
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+def test_rownumber_none(conn, query):
+ cur = conn.cursor()
+ cur.execute(query)
+ assert cur.rownumber is None
+
+
+def test_rownumber_mixed(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
+def test_iter(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ assert list(cur) == [(1,), (2,), (3,)]
+
+
+def test_iter_stop(conn):
+ cur = conn.cursor()
+ cur.execute("select generate_series(1, 3)")
+ for rec in cur:
+ assert rec == (1,)
+ break
+
+ for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert cur.fetchone() == (3,)
+ assert list(cur) == []
+
+
+def test_row_factory(conn):
+ cur = conn.cursor(row_factory=my_row_factory)
+
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+ cur.execute("select 'foo' as bar")
+ (r,) = cur.fetchone()
+ assert r == "FOObar"
+
+ cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert cur.fetchall() == [["Yy", "Zz"]]
+
+ cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"y": "y", "z": "z"}
+
+
+def test_row_factory_none(conn):
+ cur = conn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ r = cur.execute("select 1 as a, 2 as b").fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+def test_bad_row_factory(conn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = conn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = conn.cursor(row_factory=broken_maker)
+ cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ cur.fetchone()
+
+
+def test_scroll(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.scroll(0)
+
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(-1)
+ assert cur.fetchone() == (8,)
+ cur.scroll(-2)
+ assert cur.fetchone() == (7,)
+ cur.scroll(2, mode="absolute")
+ assert cur.fetchone() == (2,)
+
+ # on the boundary
+ cur.scroll(0, mode="absolute")
+ assert cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ cur.scroll(-1, mode="absolute")
+
+ cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(-1)
+
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ cur.scroll(10, mode="absolute")
+
+ cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ cur.scroll(1, "wat")
+
+
+def test_query_params_execute(conn):
+ cur = conn.cursor()
+ assert cur._query is None
+
+ cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select $1, $2::text"
+ assert cur._query.params == [b"1", None]
+
+ cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select $1::int"
+ assert cur._query.params == [b"wat"]
+
+
+def test_query_params_executemany(conn):
+ cur = conn.cursor()
+
+ cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select $1, $2"
+ assert cur._query.params == [b"3", b"4"]
+
+
+def test_stream(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+def test_stream_sql(conn):
+ cur = conn.cursor()
+ recs = list(
+ cur.stream(
+ sql.SQL(
+ "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+ ).format(2)
+ )
+ )
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+def test_stream_row_factory(conn):
+ cur = conn.cursor(row_factory=rows.dict_row)
+ it = iter(cur.stream("select generate_series(1,2) as a"))
+ assert next(it)["a"] == 1
+ cur.row_factory = rows.namedtuple_row
+ assert next(it).a == 2
+
+
+def test_stream_no_row(conn):
+ cur = conn.cursor()
+ recs = list(cur.stream("select generate_series(2,1) as a"))
+ assert recs == []
+
+
+@pytest.mark.crdb_skip("no col query")
+def test_stream_no_col(conn):
+ cur = conn.cursor()
+ recs = list(cur.stream("select"))
+ assert recs == [()]
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_stream_badq ()",
+ "copy (select 1) to stdout",
+ "wat?",
+ ],
+)
+def test_stream_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream(query):
+ pass
+
+
+def test_stream_error_tx(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream("wat"):
+ pass
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_stream_error_notx(conn):
+ conn.autocommit = True
+ cur = conn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ for rec in cur.stream("wat"):
+ pass
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+
+
+def test_stream_error_python_to_consume(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with closing(cur.stream("select generate_series(1, 10000)")) as gen:
+ for rec in gen:
+ 1 / 0
+ assert conn.info.transaction_status in (
+ conn.TransactionStatus.INTRANS,
+ conn.TransactionStatus.INERROR,
+ )
+
+
+def test_stream_error_python_consumed(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select 1")
+ for rec in gen:
+ 1 / 0
+ gen.close()
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_stream_close(conn):
+ cur = conn.cursor()
+ with pytest.raises(psycopg.OperationalError):
+ for rec in cur.stream("select generate_series(1, 3)"):
+ if rec[0] == 1:
+ conn.close()
+ else:
+ assert False
+
+ assert conn.closed
+
+
+def test_stream_binary_cursor(conn):
+ cur = conn.cursor(binary=True)
+ recs = []
+ for rec in cur.stream("select x::int4 from generate_series(1, 2) x"):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+def test_stream_execute_binary(conn):
+ cur = conn.cursor()
+ recs = []
+ for rec in cur.stream("select x::int4 from generate_series(1, 2) x", binary=True):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+def test_stream_binary_cursor_text_override(conn):
+ cur = conn.cursor(binary=True)
+ recs = []
+ for rec in cur.stream("select generate_series(1, 2)", binary=False):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode()
+
+ assert recs == [(1,), (2,)]
+
+
+class TestColumn:
+ def test_description_attribs(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ assert len(curs.description) == 3
+ for c in curs.description:
+ len(c) == 7 # DBAPI happy
+ for i, a in enumerate(
+ """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ ):
+ assert c[i] == getattr(c, a)
+
+ # Won't fill them up
+ assert c.null_ok is None
+
+ c = curs.description[0]
+ assert c.name == "pi"
+ assert c.type_code == builtins["numeric"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision == 10
+ assert c.scale == 2
+
+ c = curs.description[1]
+ assert c.name == "hi"
+ assert c.type_code == builtins["text"].oid
+ assert c.display_size is None
+ assert c.internal_size is None
+ assert c.precision is None
+ assert c.scale is None
+
+ c = curs.description[2]
+ assert c.name == "now"
+ assert c.type_code == builtins["date"].oid
+ assert c.display_size is None
+ if is_crdb(conn):
+ assert c.internal_size == 16
+ else:
+ assert c.internal_size == 4
+ assert c.precision is None
+ assert c.scale is None
+
+ def test_description_slice(self, conn):
+ curs = conn.cursor()
+ curs.execute("select 1::int as a")
+ curs.description[0][0:2] == ("a", 23)
+
+ @pytest.mark.parametrize(
+ "type, precision, scale, dsize, isize",
+ [
+ ("text", None, None, None, None),
+ ("varchar", None, None, None, None),
+ ("varchar(42)", None, None, 42, None),
+ ("int4", None, None, None, 4),
+ ("numeric", None, None, None, None),
+ ("numeric(10)", 10, 0, None, None),
+ ("numeric(10, 3)", 10, 3, None, None),
+ ("time", None, None, None, 8),
+ crdb_time_precision("time(4)", 4, None, None, 8),
+ crdb_time_precision("time(10)", 6, None, None, 8),
+ ],
+ )
+ def test_details(self, conn, type, precision, scale, dsize, isize):
+ cur = conn.cursor()
+ cur.execute(f"select null::{type}")
+ col = cur.description[0]
+ repr(col)
+ assert col.precision == precision
+ assert col.scale == scale
+ assert col.display_size == dsize
+ assert col.internal_size == isize
+
+ def test_pickle(self, conn):
+ curs = conn.cursor()
+ curs.execute(
+ """select
+ 3.14::decimal(10,2) as pi,
+ 'hello'::text as hi,
+ '2010-02-18'::date as now
+ """
+ )
+ description = curs.description
+ pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+ unpickled = pickle.loads(pickled)
+ assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]
+
+ @pytest.mark.crdb_skip("no col query")
+ def test_no_col_query(self, conn):
+ cur = conn.execute("select")
+ assert cur.description == []
+ assert cur.fetchall() == [()]
+
+ def test_description_closed_connection(self, conn):
+ # If we have reasons to break this test we will (e.g. we really need
+ # the connection). In #172 it fails just by accident.
+ cur = conn.execute("select 1::int4 as foo")
+ conn.close()
+ assert len(cur.description) == 1
+ col = cur.description[0]
+ assert col.name == "foo"
+ assert col.type_code == 23
+
+ def test_name_not_a_name(self, conn):
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "foo-bar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_name_encode(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ cur = conn.cursor()
+ (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone()
+ assert res == "x"
+ assert cur.description[0].name == "\u20ac"
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ assert "psycopg.Cursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+ faker.format = fmt
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ def work():
+ with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True):
+ with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ cur.fetchall()
+ elif fetch == "iter":
+ for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def my_row_factory(
+ cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]]
+) -> RowMaker[List[str]]:
+ if cursor.description is not None:
+ titles = [c.name for c in cursor.description]
+
+ def mkrow(values):
+ return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
+
+ return mkrow
+ else:
+ return rows.no_result
diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py
new file mode 100644
index 0000000..ac3fdeb
--- /dev/null
+++ b/tests/test_cursor_async.py
@@ -0,0 +1,802 @@
+import pytest
+import weakref
+import datetime as dt
+from typing import List
+
+import psycopg
+from psycopg import pq, sql, rows
+from psycopg.adapt import PyFormat
+
+from .utils import gc_collect, gc_count
+from .test_cursor import my_row_factory
+from .test_cursor import execmany, _execmany # noqa: F401
+from .fix_crdb import crdb_encoding
+
+execmany = execmany # avoid F811 underneath
+pytestmark = pytest.mark.asyncio
+
+
+async def test_init(aconn):
+ cur = psycopg.AsyncCursor(aconn)
+ await cur.execute("select 1")
+ assert (await cur.fetchone()) == (1,)
+
+ aconn.row_factory = rows.dict_row
+ cur = psycopg.AsyncCursor(aconn)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_init_factory(aconn):
+ cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row)
+ await cur.execute("select 1 as a")
+ assert (await cur.fetchone()) == {"a": 1}
+
+
+async def test_close(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.execute("select 'foo'")
+
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor()
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(psycopg.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_context(aconn):
+ async with aconn.cursor() as cur:
+ assert not cur.closed
+
+ assert cur.closed
+
+
+@pytest.mark.slow
+async def test_weakref(aconn):
+ cur = aconn.cursor()
+ w = weakref.ref(cur)
+ await cur.close()
+ del cur
+ gc_collect()
+ assert w() is None
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_statusmessage(aconn):
+ cur = aconn.cursor()
+ assert cur.statusmessage is None
+
+ await cur.execute("select generate_series(1, 10)")
+ assert cur.statusmessage == "SELECT 10"
+
+ await cur.execute("create table statusmessage ()")
+ assert cur.statusmessage == "CREATE TABLE"
+
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute("wat")
+ assert cur.statusmessage is None
+
+
+async def test_execute_many_results(aconn):
+ cur = aconn.cursor()
+ assert cur.nextset() is None
+
+ rv = await cur.execute("select 'foo'; select generate_series(1,3)")
+ assert rv is cur
+ assert (await cur.fetchall()) == [("foo",)]
+ assert cur.rowcount == 1
+ assert cur.nextset()
+ assert (await cur.fetchall()) == [(1,), (2,), (3,)]
+ assert cur.rowcount == 3
+ assert cur.nextset() is None
+
+ await cur.close()
+ assert cur.nextset() is None
+
+
+async def test_execute_sequence(aconn):
+ cur = aconn.cursor()
+ rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.pgresult.get_value(0, 1) == b"foo"
+ assert cur.pgresult.get_value(0, 2) is None
+ assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+async def test_execute_empty_query(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_execute_type_change(aconn):
+ # issue #112
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.execute(sql, (1,))
+ await cur.execute(sql, (100_000,))
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+async def test_executemany_type_change(aconn):
+ await aconn.execute("create table bug_112 (num integer)")
+ sql = "insert into bug_112 (num) values (%s)"
+ cur = aconn.cursor()
+ await cur.executemany(sql, [(1,), (100_000,)])
+ await cur.execute("select num from bug_112 order by num")
+ assert (await cur.fetchall()) == [(1,), (100_000,)]
+
+
+@pytest.mark.parametrize(
+ "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
+)
+async def test_execute_copy(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute("create table testcopy (id int)")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.execute(query)
+
+
+async def test_fetchone(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
+ assert cur.pgresult.fformat(0) == 0
+
+ row = await cur.fetchone()
+ assert row == (1, "foo", None)
+ row = await cur.fetchone()
+ assert row is None
+
+
+async def test_binary_cursor_execute(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None])
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+async def test_execute_binary(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s, %s", [1, None], binary=True)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x01"
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ await cur.execute("select %s, %s", [1, None], binary=False)
+ assert (await cur.fetchone()) == (1, None)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+
+
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+async def test_query_encode(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ await cur.execute("select '\u20ac'")
+ (res,) = await cur.fetchone()
+ assert res == "\u20ac"
+
+
+@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
+async def test_query_badenc(aconn, encoding):
+ await aconn.execute(f"set client_encoding to {encoding}")
+ cur = aconn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ await cur.execute("select '\u20ac'")
+
+
+async def test_executemany(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(10, "hello"), (20, "world")]
+
+
+async def test_executemany_name(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ await cur.execute("select num, data from execmany order by 1")
+ rv = await cur.fetchall()
+ assert rv == [(11, "hello"), (21, "world")]
+
+
+async def test_executemany_no_data(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
+ assert cur.rowcount == 0
+
+
+async def test_executemany_rowcount(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+async def test_executemany_returning(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_returning_discard(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s) returning num",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_result(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.statusmessage.startswith("INSERT")
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.fetchone()
+ pgresult = cur.pgresult
+ assert cur.nextset()
+ assert cur.statusmessage.startswith("INSERT")
+ assert pgresult is not cur.pgresult
+ assert cur.nextset() is None
+
+
+async def test_executemany_rowcount_no_hit(aconn, execmany):
+ cur = aconn.cursor()
+ await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)])
+ assert cur.rowcount == 0
+ await cur.executemany("delete from execmany where id = %s", [])
+ assert cur.rowcount == 0
+ await cur.executemany(
+ "delete from execmany where id = %s returning num", [(-1,), (-2,)]
+ )
+ assert cur.rowcount == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+async def test_executemany_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.executemany(query, [(10, "hello"), (20, "world")])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+async def test_executemany_null_first(aconn, fmt_in):
+ cur = aconn.cursor()
+ await cur.execute("create table testmany (a bigint, b bigint)")
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, None], [3, 4]],
+ )
+ with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
+ await cur.executemany(
+ f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ [[1, ""], [3, 4]],
+ )
+
+
+async def test_rowcount(aconn):
+ cur = aconn.cursor()
+
+ await cur.execute("select 1 from generate_series(1, 0)")
+ assert cur.rowcount == 0
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rowcount == 42
+
+ await cur.execute("show timezone")
+ assert cur.rowcount == 1
+
+ await cur.execute("create table test_rowcount_notuples (id int primary key)")
+ assert cur.rowcount == -1
+
+ await cur.execute(
+ "insert into test_rowcount_notuples select generate_series(1, 42)"
+ )
+ assert cur.rowcount == 42
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor()
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ rns: List[int] = []
+ async for i in cur:
+ assert cur.rownumber
+ rns.append(cur.rownumber)
+ if len(rns) >= 3:
+ break
+ assert rns == [13, 14, 15]
+ assert len(await cur.fetchall()) == 42 - rns[-1]
+ assert cur.rownumber == 42
+
+
+@pytest.mark.parametrize("query", ["", "set timezone to utc"])
+async def test_rownumber_none(aconn, query):
+ cur = aconn.cursor()
+ await cur.execute(query)
+ assert cur.rownumber is None
+
+
+async def test_rownumber_mixed(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+select x from generate_series(1, 3) x;
+set timezone to utc;
+select x from generate_series(4, 6) x;
+"""
+ )
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (1,)
+ assert cur.rownumber == 1
+ assert await cur.fetchone() == (2,)
+ assert cur.rownumber == 2
+ cur.nextset()
+ assert cur.rownumber is None
+ cur.nextset()
+ assert cur.rownumber == 0
+ assert await cur.fetchone() == (4,)
+ assert cur.rownumber == 1
+
+
+async def test_iter(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ res = []
+ async for rec in cur:
+ res.append(rec)
+ assert res == [(1,), (2,), (3,)]
+
+
+async def test_iter_stop(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select generate_series(1, 3)")
+ async for rec in cur:
+ assert rec == (1,)
+ break
+
+ async for rec in cur:
+ assert rec == (2,)
+ break
+
+ assert (await cur.fetchone()) == (3,)
+ async for rec in cur:
+ assert False
+
+
+async def test_row_factory(aconn):
+ cur = aconn.cursor(row_factory=my_row_factory)
+ await cur.execute("select 'foo' as bar")
+ (r,) = await cur.fetchone()
+ assert r == "FOObar"
+
+ await cur.execute("select 'x' as x; select 'y' as y, 'z' as z")
+ assert await cur.fetchall() == [["Xx"]]
+ assert cur.nextset()
+ assert await cur.fetchall() == [["Yy", "Zz"]]
+
+ await cur.scroll(-1)
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"y": "y", "z": "z"}
+
+
+async def test_row_factory_none(aconn):
+ cur = aconn.cursor(row_factory=None)
+ assert cur.row_factory is rows.tuple_row
+ await cur.execute("select 1 as a, 2 as b")
+ r = await cur.fetchone()
+ assert type(r) is tuple
+ assert r == (1, 2)
+
+
+async def test_bad_row_factory(aconn):
+ def broken_factory(cur):
+ 1 / 0
+
+ cur = aconn.cursor(row_factory=broken_factory)
+ with pytest.raises(ZeroDivisionError):
+ await cur.execute("select 1")
+
+ def broken_maker(cur):
+ def make_row(seq):
+ 1 / 0
+
+ return make_row
+
+ cur = aconn.cursor(row_factory=broken_maker)
+ await cur.execute("select 1")
+ with pytest.raises(ZeroDivisionError):
+ await cur.fetchone()
+
+
+async def test_scroll(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.scroll(0)
+
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-1)
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(-2)
+ assert await cur.fetchone() == (7,)
+ await cur.scroll(2, mode="absolute")
+ assert await cur.fetchone() == (2,)
+
+ # on the boundary
+ await cur.scroll(0, mode="absolute")
+ assert await cur.fetchone() == (0,)
+ with pytest.raises(IndexError):
+ await cur.scroll(-1, mode="absolute")
+
+ await cur.scroll(0, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(-1)
+
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+ with pytest.raises(IndexError):
+ await cur.scroll(10, mode="absolute")
+
+ await cur.scroll(9, mode="absolute")
+ with pytest.raises(IndexError):
+ await cur.scroll(1)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(1, "wat")
+
+
+async def test_query_params_execute(aconn):
+ cur = aconn.cursor()
+ assert cur._query is None
+
+ await cur.execute("select %t, %s::text", [1, None])
+ assert cur._query is not None
+ assert cur._query.query == b"select $1, $2::text"
+ assert cur._query.params == [b"1", None]
+
+ await cur.execute("select 1")
+ assert cur._query.query == b"select 1"
+ assert not cur._query.params
+
+ with pytest.raises(psycopg.DataError):
+ await cur.execute("select %t::int", ["wat"])
+
+ assert cur._query.query == b"select $1::int"
+ assert cur._query.params == [b"wat"]
+
+
+async def test_query_params_executemany(aconn):
+ cur = aconn.cursor()
+
+ await cur.executemany("select %t, %t", [[1, 2], [3, 4]])
+ assert cur._query.query == b"select $1, $2"
+ assert cur._query.params == [b"3", b"4"]
+
+
+async def test_stream(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
+ [2],
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_stream_sql(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ sql.SQL(
+ "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
+ ).format(2)
+ ):
+ recs.append(rec)
+
+ assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]
+
+
+async def test_stream_row_factory(aconn):
+ cur = aconn.cursor(row_factory=rows.dict_row)
+ ait = cur.stream("select generate_series(1,2) as a")
+ assert (await ait.__anext__())["a"] == 1
+ cur.row_factory = rows.namedtuple_row
+ assert (await ait.__anext__()).a == 2
+
+
+async def test_stream_no_row(aconn):
+ cur = aconn.cursor()
+ recs = [rec async for rec in cur.stream("select generate_series(2,1) as a")]
+ assert recs == []
+
+
+@pytest.mark.crdb_skip("no col query")
+async def test_stream_no_col(aconn):
+ cur = aconn.cursor()
+ recs = [rec async for rec in cur.stream("select")]
+ assert recs == [()]
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_stream_badq ()",
+ "copy (select 1) to stdout",
+ "wat?",
+ ],
+)
+async def test_stream_badquery(aconn, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream(query):
+ pass
+
+
+async def test_stream_error_tx(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream("wat"):
+ pass
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_stream_error_notx(aconn):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.ProgrammingError):
+ async for rec in cur.stream("wat"):
+ pass
+ assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE
+
+
+async def test_stream_error_python_to_consume(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select generate_series(1, 10000)")
+ async for rec in gen:
+ 1 / 0
+
+ await gen.aclose()
+ assert aconn.info.transaction_status in (
+ aconn.TransactionStatus.INTRANS,
+ aconn.TransactionStatus.INERROR,
+ )
+
+
+async def test_stream_error_python_consumed(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ gen = cur.stream("select 1")
+ async for rec in gen:
+ 1 / 0
+
+ await gen.aclose()
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_stream_close(aconn):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ with pytest.raises(psycopg.OperationalError):
+ async for rec in cur.stream("select generate_series(1, 3)"):
+ if rec[0] == 1:
+ await aconn.close()
+ else:
+ assert False
+
+ assert aconn.closed
+
+
+async def test_stream_binary_cursor(aconn):
+ cur = aconn.cursor(binary=True)
+ recs = []
+ async for rec in cur.stream("select x::int4 from generate_series(1, 2) x"):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_stream_execute_binary(aconn):
+ cur = aconn.cursor()
+ recs = []
+ async for rec in cur.stream(
+ "select x::int4 from generate_series(1, 2) x", binary=True
+ ):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]])
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_stream_binary_cursor_text_override(aconn):
+ cur = aconn.cursor(binary=True)
+ recs = []
+ async for rec in cur.stream("select generate_series(1, 2)", binary=False):
+ recs.append(rec)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode()
+
+ assert recs == [(1,), (2,)]
+
+
+async def test_str(aconn):
+ cur = aconn.cursor()
+ assert "psycopg.AsyncCursor" in str(cur)
+ assert "[IDLE]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" in str(cur)
+ await cur.execute("select 1")
+ assert "[INTRANS]" in str(cur)
+ assert "[TUPLES_OK]" in str(cur)
+ assert "[closed]" not in str(cur)
+ assert "[no result]" not in str(cur)
+ await cur.close()
+ assert "[closed]" in str(cur)
+ assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+ faker.format = fmt
+ faker.choose_schema(ncols=5)
+ faker.make_records(10)
+ row_factory = getattr(rows, row_factory)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn, conn.transaction(
+ force_rollback=True
+ ):
+ async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+ async with faker.find_insert_problem_async(conn):
+ await cur.executemany(faker.insert_stmt, faker.records)
+ await cur.execute(faker.select_stmt)
+
+ if fetch == "one":
+ while True:
+ tmp = await cur.fetchone()
+ if tmp is None:
+ break
+ elif fetch == "many":
+ while True:
+ tmp = await cur.fetchmany(3)
+ if not tmp:
+ break
+ elif fetch == "all":
+ await cur.fetchall()
+ elif fetch == "iter":
+ async for rec in cur:
+ pass
+
+ n = []
+ gc_collect()
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
diff --git a/tests/test_dns.py b/tests/test_dns.py
new file mode 100644
index 0000000..f50092f
--- /dev/null
+++ b/tests/test_dns.py
@@ -0,0 +1,27 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+pytestmark = [pytest.mark.dns]
+
+
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_warning(recwarn):
+ import_dnspython()
+ conninfo = "dbname=foo"
+ params = conninfo_to_dict(conninfo)
+ params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined]
+ params
+ )
+ assert conninfo_to_dict(conninfo) == params
+ assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
+
+
+def import_dnspython():
+ try:
+ import dns.rdtypes.IN.A # noqa: F401
+ except ImportError:
+ pytest.skip("dnspython package not available")
+
+ import psycopg._dns # noqa: F401
diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py
new file mode 100644
index 0000000..15b3706
--- /dev/null
+++ b/tests/test_dns_srv.py
@@ -0,0 +1,149 @@
+from typing import List, Union
+
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_dns import import_dnspython
+
+pytestmark = [pytest.mark.dns]
+
+samples_ok = [
+ ("", "", None),
+ ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None),
+ ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}),
+ (
+ "host=foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.bar.com",
+ "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com"
+ " port=5432,5432,5433,5432",
+ None,
+ ),
+ (
+ "host=service.foo.com port=srv",
+ "host=service.example.com port=15432",
+ None,
+ ),
+ # No resolution
+ (
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+]
+
+
+@pytest.mark.flakey("random weight order, might cause wrong order")
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+def test_srv(conninfo, want, env, fake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+async def test_srv_async(conninfo, want, env, afake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = await (
+ psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]
+ )
+ assert conninfo_to_dict(want) == params
+
+
+samples_bad = [
+ ("host=_pg._tcp.dot.com", None),
+ ("host=_pg._tcp.foo.com port=1,2", None),
+]
+
+
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+def test_srv_bad(conninfo, env, fake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ await psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]
+
+
+@pytest.fixture
+def fake_srv(monkeypatch):
+ f = get_fake_srv_function(monkeypatch)
+ monkeypatch.setattr(
+ psycopg._dns.resolver, # type: ignore[attr-defined]
+ "resolve",
+ f,
+ )
+
+
+@pytest.fixture
+def afake_srv(monkeypatch):
+ f = get_fake_srv_function(monkeypatch)
+
+ async def af(qname, rdtype):
+ return f(qname, rdtype)
+
+ monkeypatch.setattr(
+ psycopg._dns.async_resolver, # type: ignore[attr-defined]
+ "resolve",
+ af,
+ )
+
+
+def get_fake_srv_function(monkeypatch):
+ import_dnspython()
+
+ from dns.rdtypes.IN.A import A
+ from dns.rdtypes.IN.SRV import SRV
+ from dns.exception import DNSException
+
+ fake_hosts = {
+ ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."],
+ ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."],
+ ("_pg._tcp.bar.com", "SRV"): [
+ "1 0 5432 db2.example.com.",
+ "1 255 5433 db3.example.com.",
+ "0 0 5432 db1.example.com.",
+ "1 65535 5432 db4.example.com.",
+ ],
+ ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."],
+ }
+
+ def fake_srv_(qname, rdtype):
+ try:
+ ans = fake_hosts[qname, rdtype]
+ except KeyError:
+ raise DNSException(f"unknown test host: {qname} {rdtype}")
+ rv: List[Union[A, SRV]] = []
+
+ if rdtype == "A":
+ for entry in ans:
+ rv.append(A("IN", "A", entry))
+ else:
+ for entry in ans:
+ pri, w, port, target = entry.split()
+ rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target))
+
+ return rv
+
+ return fake_srv_
diff --git a/tests/test_encodings.py b/tests/test_encodings.py
new file mode 100644
index 0000000..113f0e3
--- /dev/null
+++ b/tests/test_encodings.py
@@ -0,0 +1,57 @@
+import codecs
+import pytest
+
+import psycopg
+from psycopg import _encodings as encodings
+
+
+def test_names_normalised():
+ for name in encodings._py_codecs.values():
+ assert codecs.lookup(name).name == name
+
+
+@pytest.mark.parametrize(
+ "pyenc, pgenc",
+ [
+ ("ascii", "SQL_ASCII"),
+ ("utf8", "UTF8"),
+ ("utf-8", "UTF8"),
+ ("uTf-8", "UTF8"),
+ ("latin9", "LATIN9"),
+ ("iso8859-15", "LATIN9"),
+ ],
+)
+def test_py2pg(pyenc, pgenc):
+ assert encodings.py2pgenc(pyenc) == pgenc.encode()
+
+
+@pytest.mark.parametrize(
+ "pyenc, pgenc",
+ [
+ ("ascii", "SQL_ASCII"),
+ ("utf-8", "UTF8"),
+ ("iso8859-15", "LATIN9"),
+ ],
+)
+def test_pg2py(pyenc, pgenc):
+ assert encodings.pg2pyenc(pgenc.encode()) == pyenc
+
+
+@pytest.mark.parametrize("pgenc", ["MULE_INTERNAL", "EUC_TW"])
+def test_pg2py_missing(pgenc):
+ with pytest.raises(psycopg.NotSupportedError):
+ encodings.pg2pyenc(pgenc.encode())
+
+
+@pytest.mark.parametrize(
+ "conninfo, pyenc",
+ [
+ ("", "utf-8"),
+ ("user=foo, dbname=bar", "utf-8"),
+ ("user=foo, dbname=bar, client_encoding=EUC_JP", "euc_jp"),
+ ("user=foo, dbname=bar, client_encoding=euc-jp", "euc_jp"),
+ ("user=foo, dbname=bar, client_encoding=WAT", "utf-8"),
+ ],
+)
+def test_conninfo_encoding(conninfo, pyenc):
+ assert encodings.conninfo_encoding(conninfo) == pyenc
diff --git a/tests/test_errors.py b/tests/test_errors.py
new file mode 100644
index 0000000..23ad314
--- /dev/null
+++ b/tests/test_errors.py
@@ -0,0 +1,309 @@
+import pickle
+from typing import List
+from weakref import ref
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+from .utils import eur, gc_collect
+from .fix_crdb import is_crdb
+
+
+@pytest.mark.crdb_skip("severity_nonlocalized")
+def test_error_diag(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ exc = excinfo.value
+ diag = exc.diag
+ assert diag.sqlstate == "42P01"
+ assert diag.severity_nonlocalized == "ERROR"
+
+
+def test_diag_all_attrs(pgconn):
+ res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+ diag = e.Diagnostic(res)
+ for d in pq.DiagnosticField:
+ val = getattr(diag, d.name.lower())
+ assert val is None or isinstance(val, str)
+
+
+def test_diag_right_attr(pgconn, monkeypatch):
+ res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR)
+ diag = e.Diagnostic(res)
+
+ to_check: pq.DiagnosticField
+ checked: List[pq.DiagnosticField] = []
+
+ def check_val(self, v):
+ nonlocal to_check
+ assert to_check == v
+ checked.append(v)
+ return None
+
+ monkeypatch.setattr(e.Diagnostic, "_error_message", check_val)
+
+ for to_check in pq.DiagnosticField:
+ getattr(diag, to_check.name.lower())
+
+ assert len(checked) == len(pq.DiagnosticField)
+
+
+def test_diag_attr_values(conn):
+ if is_crdb(conn):
+ conn.execute("set experimental_enable_temp_tables = 'on'")
+ conn.execute(
+ """
+ create temp table test_exc (
+ data int constraint chk_eq1 check (data = 1)
+ )"""
+ )
+ with pytest.raises(e.Error) as exc:
+ conn.execute("insert into test_exc values(2)")
+ diag = exc.value.diag
+ assert diag.sqlstate == "23514"
+ assert diag.constraint_name == "chk_eq1"
+ if not is_crdb(conn):
+ assert diag.table_name == "test_exc"
+ assert diag.schema_name and diag.schema_name[:7] == "pg_temp"
+ assert diag.severity_nonlocalized == "ERROR"
+
+
+@pytest.mark.crdb_skip("do")
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_diag_encoding(conn, enc):
+ msgs = []
+ conn.pgconn.exec_(b"set client_min_messages to notice")
+ conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary))
+ conn.execute(f"set client_encoding to {enc}")
+ cur = conn.cursor()
+ cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql")
+ assert msgs == [f"hello {eur}"]
+
+
+@pytest.mark.crdb_skip("do")
+@pytest.mark.parametrize("enc", ["utf8", "latin9"])
+def test_error_encoding(conn, enc):
+ with conn.transaction():
+ conn.execute(f"set client_encoding to {enc}")
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute(
+ """
+ do $$begin
+ execute format('insert into "%s" values (1)', chr(8364));
+ end$$ language plpgsql;
+ """
+ )
+
+ diag = excinfo.value.diag
+ assert diag.message_primary and f'"{eur}"' in diag.message_primary
+ assert diag.sqlstate == "42P01"
+
+
+def test_exception_class(conn):
+ cur = conn.cursor()
+
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select * from nonexist")
+
+ assert isinstance(excinfo.value, e.UndefinedTable)
+ assert isinstance(excinfo.value, conn.ProgrammingError)
+
+
+def test_exception_class_fallback(conn):
+ cur = conn.cursor()
+
+ x = e._sqlcodes.pop("42P01")
+ try:
+ with pytest.raises(e.Error) as excinfo:
+ cur.execute("select * from nonexist")
+ finally:
+ e._sqlcodes["42P01"] = x
+
+ assert type(excinfo.value) is conn.ProgrammingError
+
+
+def test_lookup():
+ assert e.lookup("42P01") is e.UndefinedTable
+ assert e.lookup("42p01") is e.UndefinedTable
+ assert e.lookup("UNDEFINED_TABLE") is e.UndefinedTable
+ assert e.lookup("undefined_table") is e.UndefinedTable
+
+ with pytest.raises(KeyError):
+ e.lookup("XXXXX")
+
+
+def test_error_sqlstate():
+ assert e.Error.sqlstate is None
+ assert e.ProgrammingError.sqlstate is None
+ assert e.UndefinedTable.sqlstate == "42P01"
+
+
+def test_error_pickle(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert isinstance(exc, e.UndefinedTable)
+ assert exc.diag.sqlstate == "42P01"
+
+
+def test_diag_pickle(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DatabaseError) as excinfo:
+ cur.execute("select 1 from wat")
+
+ diag1 = excinfo.value.diag
+ diag2 = pickle.loads(pickle.dumps(diag1))
+
+ assert isinstance(diag2, type(diag1))
+ for f in pq.DiagnosticField:
+ assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower())
+
+ assert diag2.sqlstate == "42P01"
+
+
+@pytest.mark.slow
+def test_diag_survives_cursor(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.Error) as exc:
+ cur.execute("select * from nosuchtable")
+
+ diag = exc.value.diag
+ del exc
+ w = ref(cur)
+ del cur
+ gc_collect()
+ assert w() is None
+ assert diag.sqlstate == "42P01"
+
+
+def test_diag_independent(conn):
+ conn.autocommit = True
+ cur = conn.cursor()
+
+ with pytest.raises(e.Error) as exc1:
+ cur.execute("l'acqua e' poca e 'a papera nun galleggia")
+
+ with pytest.raises(e.Error) as exc2:
+ cur.execute("select level from water where ducks > 1")
+
+ assert exc1.value.diag.sqlstate == "42601"
+ assert exc2.value.diag.sqlstate == "42P01"
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_diag_from_commit(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create temp table test_deferred (
+ data int primary key,
+ ref int references test_deferred (data)
+ deferrable initially deferred)
+ """
+ )
+ cur.execute("insert into test_deferred values (1,2)")
+ with pytest.raises(e.Error) as exc:
+ conn.commit()
+
+ assert exc.value.diag.sqlstate == "23503"
+
+
+@pytest.mark.asyncio
+@pytest.mark.crdb_skip("deferrable")
+async def test_diag_from_commit_async(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+ create temp table test_deferred (
+ data int primary key,
+ ref int references test_deferred (data)
+ deferrable initially deferred)
+ """
+ )
+ await cur.execute("insert into test_deferred values (1,2)")
+ with pytest.raises(e.Error) as exc:
+ await aconn.commit()
+
+ assert exc.value.diag.sqlstate == "23503"
+
+
+def test_query_context(conn):
+ with pytest.raises(e.Error) as exc:
+ conn.execute("select * from wat")
+
+ s = str(exc.value)
+ if not is_crdb(conn):
+ assert "from wat" in s, s
+ assert exc.value.diag.message_primary
+ assert exc.value.diag.message_primary in s
+ assert "ERROR" not in s
+ assert not s.endswith("\n")
+
+
+@pytest.mark.crdb_skip("do")
+def test_unknown_sqlstate(conn):
+ code = "PXX99"
+ with pytest.raises(KeyError):
+ e.lookup(code)
+
+ with pytest.raises(e.ProgrammingError) as excinfo:
+ conn.execute(
+ f"""
+ do $$begin
+ raise exception 'made up code' using errcode = '{code}';
+ end$$ language plpgsql
+ """
+ )
+ exc = excinfo.value
+ assert exc.diag.sqlstate == code
+ assert exc.sqlstate == code
+ # Survives pickling too
+ pexc = pickle.loads(pickle.dumps(exc))
+ assert pexc.sqlstate == code
+
+
+def test_pgconn_error(conn_cls):
+ with pytest.raises(psycopg.OperationalError) as excinfo:
+ conn_cls.connect("dbname=nosuchdb")
+
+ exc = excinfo.value
+ assert exc.pgconn
+ assert exc.pgconn.db == b"nosuchdb"
+
+
+def test_pgconn_error_pickle(conn_cls):
+ with pytest.raises(psycopg.OperationalError) as excinfo:
+ conn_cls.connect("dbname=nosuchdb")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert exc.pgconn is None
+
+
+def test_pgresult(conn):
+ with pytest.raises(e.DatabaseError) as excinfo:
+ conn.execute("select 1 from wat")
+
+ exc = excinfo.value
+ assert exc.pgresult
+ assert exc.pgresult.error_field(pq.DiagnosticField.SQLSTATE) == b"42P01"
+
+
+def test_pgresult_pickle(conn):
+ with pytest.raises(e.DatabaseError) as excinfo:
+ conn.execute("select 1 from wat")
+
+ exc = pickle.loads(pickle.dumps(excinfo.value))
+ assert exc.pgresult is None
+ assert exc.diag.sqlstate == "42P01"
+
+
+def test_blank_sqlstate(conn):
+ assert e.get_base_exception("") is e.DatabaseError
diff --git a/tests/test_generators.py b/tests/test_generators.py
new file mode 100644
index 0000000..8aba73f
--- /dev/null
+++ b/tests/test_generators.py
@@ -0,0 +1,156 @@
+from collections import deque
+from functools import partial
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import waiting
+from psycopg import pq
+
+
+@pytest.fixture
+def pipeline(pgconn):
+ nb, pgconn.nonblocking = pgconn.nonblocking, True
+ assert pgconn.nonblocking
+ pgconn.enter_pipeline_mode()
+ yield
+ if pgconn.pipeline_status:
+ pgconn.exit_pipeline_mode()
+ pgconn.nonblocking = nb
+
+
+def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses):
+ actual_statuses: List[pq.ExecStatus] = []
+ while len(actual_statuses) != len(expected_statuses):
+ if commands:
+ gen = generators.pipeline_communicate(pgconn, commands)
+ results = waiting.wait(gen, pgconn.socket)
+ for (result,) in results:
+ actual_statuses.append(result.status)
+ else:
+ gen = generators.fetch_many(pgconn)
+ results = waiting.wait(gen, pgconn.socket)
+ for result in results:
+ actual_statuses.append(result.status)
+
+ assert actual_statuses == expected_statuses
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, b"select 1", None),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, b"select 2", None),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_no_sync(pgconn, pipeline, generators):
+ numqueries = 10
+ commands = deque(
+ [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)]
+ * numqueries
+ + [pgconn.send_flush_request]
+ )
+ expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.fixture
+def pipeline_demo(pgconn):
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ yield "pg_pipeline"
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
+# TODOCRDB: 1 doesn't get rolled back. Open a ticket?
+@pytest.mark.pipeline
+@pytest.mark.crdb("skip", reason="pipeline aborted")
+def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators):
+ insert_sql = b"insert into pg_pipeline(itemno) values ($1)"
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, insert_sql, [b"1"]),
+ partial(pgconn.send_query_params, b"select no_such_function(1)", None),
+ partial(pgconn.send_query_params, insert_sql, [b"2"]),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, insert_sql, [b"3"]),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+ pgconn.exit_pipeline_mode()
+ res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno")
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"3"
+
+
+@pytest.fixture
+def pipeline_uniqviol(pgconn):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol("
+ b" id bigint primary key, idata bigint)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(b"BEGIN")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.prepare(
+ b"insertion",
+ b"insert into pg_pipeline_uniqviol values ($1, $2) returning id",
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ return "pg_pipeline_uniqviol"
+
+
+def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]),
+ partial(pgconn.send_query_params, b"commit", None),
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
diff --git a/tests/test_module.py b/tests/test_module.py
new file mode 100644
index 0000000..794ef0f
--- /dev/null
+++ b/tests/test_module.py
@@ -0,0 +1,57 @@
+import pytest
+
+from psycopg._cmodule import _psycopg
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want_conninfo",
+ [
+ ((), {}, ""),
+ (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
+ ((), {"port": 15432}, "port=15432"),
+ ((), {"user": "foo", "dbname": None}, "user=foo"),
+ ],
+)
+def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+ # Check the main args passing from psycopg.connect to the conn generator
+ # Details of the params manipulation are in test_conninfo.
+ import psycopg.connection
+
+ orig_connect = psycopg.connection.connect # type: ignore
+
+ got_conninfo = None
+
+ def mock_connect(conninfo):
+ nonlocal got_conninfo
+ got_conninfo = conninfo
+ return orig_connect(dsn)
+
+ monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
+
+ conn = psycopg.connect(*args, **kwargs)
+ assert got_conninfo == want_conninfo
+ conn.close()
+
+
+def test_version(mypy):
+ cp = mypy.run_on_source(
+ """\
+from psycopg import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
+
+
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
+def test_version_c(mypy):
+ # can be psycopg_c, psycopg_binary
+ cpackage = _psycopg.__name__.split(".")[0]
+
+ cp = mypy.run_on_source(
+ f"""\
+from {cpackage} import __version__
+assert __version__
+"""
+ )
+ assert not cp.stdout
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..56fe598
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,577 @@
+import logging
+import concurrent.futures
+from typing import Any
+from operator import attrgetter
+from itertools import groupby
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+pytestmark = [
+ pytest.mark.pipeline,
+ pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
+]
+
+pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted")
+
+
+def test_repr(conn):
+ with conn.pipeline() as p:
+ assert "psycopg.Pipeline" in repr(p)
+ assert "[IDLE, pipeline=ON]" in repr(p)
+
+ conn.close()
+ assert "[BAD]" in repr(p)
+
+
+def test_connection_closed(conn):
+ conn.close()
+ with pytest.raises(e.OperationalError):
+ with conn.pipeline():
+ pass
+
+
+def test_pipeline_status(conn: psycopg.Connection[Any]) -> None:
+ assert conn._pipeline is None
+ with conn.pipeline() as p:
+ assert conn._pipeline is p
+ assert p.status == pq.PipelineStatus.ON
+ assert p.status == pq.PipelineStatus.OFF
+ assert not conn._pipeline
+
+
+def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None:
+ with conn.pipeline() as p1:
+ with conn.pipeline() as p2:
+ assert p2 is p1
+ assert p1.status == pq.PipelineStatus.ON
+ assert p2 is p1
+ assert p2.status == pq.PipelineStatus.ON
+ assert conn._pipeline is None
+ assert p1.status == pq.PipelineStatus.OFF
+
+
+def test_pipeline_broken_conn_exit(conn: psycopg.Connection[Any]) -> None:
+ with pytest.raises(e.OperationalError):
+ with conn.pipeline():
+ conn.execute("select 1")
+ conn.close()
+ closed = True
+
+ assert closed
+
+
+def test_pipeline_exit_error_noclobber(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ with conn.pipeline():
+ conn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 1
+
+
+def test_pipeline_exit_error_noclobber_nested(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ with conn.pipeline():
+ with conn.pipeline():
+ conn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
+def test_pipeline_exit_sync_trace(conn, trace):
+ t = trace.trace(conn)
+ with conn.pipeline():
+ pass
+ conn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 1
+
+
+def test_pipeline_nested_sync_trace(conn, trace):
+ t = trace.trace(conn)
+ with conn.pipeline():
+ with conn.pipeline():
+ pass
+ conn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 2
+
+
+def test_cursor_stream(conn):
+ with conn.pipeline(), conn.cursor() as cur:
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.stream("select 1").__next__()
+
+
+def test_server_cursor(conn):
+ with conn.cursor(name="pipeline") as cur, conn.pipeline():
+ with pytest.raises(psycopg.NotSupportedError):
+ cur.execute("select 1")
+
+
+def test_cannot_insert_multiple_commands(conn):
+ with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)):
+ with conn.pipeline():
+ conn.execute("select 1; select 2")
+
+
+def test_copy(conn):
+ with conn.pipeline():
+ cur = conn.cursor()
+ with pytest.raises(e.NotSupportedError):
+ with cur.copy("copy (select 1) to stdout"):
+ pass
+
+
+def test_pipeline_processed_at_exit(conn):
+ with conn.cursor() as cur:
+ with conn.pipeline() as p:
+ cur.execute("select 1")
+
+ assert len(p.result_queue) == 1
+
+ assert cur.fetchone() == (1,)
+
+
+def test_pipeline_errors_processed_at_exit(conn):
+ conn.autocommit = True
+ with pytest.raises(e.UndefinedTable):
+ with conn.pipeline():
+ conn.execute("select * from nosuchtable")
+ conn.execute("create table voila ()")
+ cur = conn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = cur.fetchone()
+ assert count == 0
+
+
+def test_pipeline(conn):
+ with conn.pipeline() as p:
+ c1 = conn.cursor()
+ c2 = conn.cursor()
+ c1.execute("select 1")
+ c2.execute("select 2")
+
+ assert len(p.result_queue) == 2
+
+ (r1,) = c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = c2.fetchone()
+ assert r2 == 2
+
+
+def test_autocommit(conn):
+ conn.autocommit = True
+ with conn.pipeline(), conn.cursor() as c:
+ c.execute("select 1")
+
+ (r,) = c.fetchone()
+ assert r == 1
+
+
+def test_pipeline_aborted(conn):
+ conn.autocommit = True
+ with conn.pipeline() as p:
+ c1 = conn.execute("select 1")
+ with pytest.raises(e.UndefinedTable):
+ conn.execute("select * from doesnotexist").fetchone()
+ with pytest.raises(e.PipelineAborted):
+ conn.execute("select 'aborted'").fetchone()
+ # Sync restore the connection in usable state.
+ p.sync()
+ c2 = conn.execute("select 2")
+
+ (r,) = c1.fetchone()
+ assert r == 1
+
+ (r,) = c2.fetchone()
+ assert r == 2
+
+
+def test_pipeline_commit_aborted(conn):
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+ conn.commit()
+
+
+def test_sync_syncs_results(conn):
+ with conn.pipeline() as p:
+ cur = conn.execute("select 1")
+ assert cur.statusmessage is None
+ p.sync()
+ assert cur.statusmessage == "SELECT 1"
+
+
+def test_sync_syncs_errors(conn):
+ conn.autocommit = True
+ with conn.pipeline() as p:
+ conn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ p.sync()
+
+
+@pipeline_aborted
+def test_errors_raised_on_commit(conn):
+ with conn.pipeline():
+ conn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ conn.commit()
+ conn.rollback()
+ cur1 = conn.execute("select 1")
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+def test_errors_raised_on_transaction_exit(conn):
+ here = False
+ with conn.pipeline():
+ with pytest.raises(e.UndefinedTable):
+ with conn.transaction():
+ conn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = conn.execute("select 1")
+ assert here
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+def test_errors_raised_on_nested_transaction_exit(conn):
+ here = False
+ with conn.pipeline():
+ with conn.transaction():
+ with pytest.raises(e.UndefinedTable):
+ with conn.transaction():
+ conn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = conn.execute("select 1")
+ assert here
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+def test_implicit_transaction(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+ conn.execute("select 'before'")
+ # Transaction is ACTIVE because previous command is not completed
+ # since we have not fetched its results.
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ # Upon entering the nested pipeline through "with transaction():", a
+ # sync() is emitted to restore the transaction state to IDLE, as
+ # expected to emit a BEGIN.
+ with conn.transaction():
+ conn.execute("select 'tx'")
+ cur = conn.execute("select 'after'")
+ assert cur.fetchone() == ("after",)
+
+
+@pytest.mark.crdb_skip("deferrable")
+def test_error_on_commit(conn):
+ conn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ conn.commit()
+
+ with conn.pipeline():
+ conn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ conn.commit()
+ cur1 = conn.execute("select 1")
+ cur2 = conn.execute("select 2")
+
+ assert cur1.fetchone() == (1,)
+ assert cur2.fetchone() == (2,)
+
+
+def test_fetch_no_result(conn):
+ with conn.pipeline():
+ cur = conn.cursor()
+ with pytest.raises(e.ProgrammingError):
+ cur.fetchone()
+
+
+def test_executemany(conn):
+ conn.autocommit = True
+ conn.execute("drop table if exists execmanypipeline")
+ conn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ with conn.pipeline(), conn.cursor() as cur:
+ cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning num",
+ [(10,), (20,)],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert cur.fetchone() == (10,)
+ assert cur.nextset()
+ assert cur.fetchone() == (20,)
+ assert cur.nextset() is None
+
+
+def test_executemany_no_returning(conn):
+ conn.autocommit = True
+ conn.execute("drop table if exists execmanypipelinenoreturning")
+ conn.execute(
+ "create unlogged table execmanypipelinenoreturning ("
+ " id serial primary key, num integer)"
+ )
+ with conn.pipeline(), conn.cursor() as cur:
+ cur.executemany(
+ "insert into execmanypipelinenoreturning(num) values (%s)",
+ [(10,), (20,)],
+ returning=False,
+ )
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ cur.fetchone()
+ assert cur.nextset() is None
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ cur.fetchone()
+ assert cur.nextset() is None
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+def test_executemany_trace(conn, trace):
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("create temp table trace (id int)")
+ t = trace.trace(conn)
+ with conn.pipeline():
+ cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ conn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"]
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+def test_executemany_trace_returning(conn, trace):
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("create temp table trace (id int)")
+ t = trace.trace(conn)
+ with conn.pipeline():
+ cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ conn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"] * 3
+ assert items[-2].direction == "F" # last 2 items are F B
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+def test_prepared(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ c1 = conn.execute("select %s::int", [10], prepare=True)
+ c2 = conn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+
+ (r,) = c1.fetchone()
+ assert r == 10
+
+ (r,) = c2.fetchone()
+ assert r == 1
+
+
+def test_auto_prepare(conn):
+ conn.autocommit = True
+ conn.prepared_threshold = 5
+ with conn.pipeline():
+ cursors = [
+ conn.execute("select count(*) from pg_prepared_statements where name != ''")
+ for i in range(10)
+ ]
+
+ assert len(conn._prepared._names) == 1
+
+ res = [c.fetchone()[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_transaction(conn):
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ with conn.pipeline():
+ with conn.transaction():
+ cur = conn.execute("select 'tx'")
+
+ (r,) = cur.fetchone()
+ assert r == "tx"
+
+ with conn.transaction():
+ cur = conn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = cur.fetchone()
+ assert r == "rb"
+
+ assert not notices
+
+
+def test_transaction_nested(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ outer = conn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ inner = conn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = outer.fetchone()
+ assert r == "outer"
+ (r,) = inner.fetchone()
+ assert r == "inner"
+
+
+def test_transaction_nested_no_statement(conn):
+ with conn.pipeline():
+ with conn.transaction():
+ with conn.transaction():
+ cur = conn.execute("select 1")
+
+ (r,) = cur.fetchone()
+ assert r == 1
+
+
+def test_outer_transaction(conn):
+ with conn.transaction():
+ conn.execute("drop table if exists outertx")
+ with conn.transaction():
+ with conn.pipeline():
+ conn.execute("create table outertx as (select 1)")
+ cur = conn.execute("select * from outertx")
+ (r,) = cur.fetchone()
+ assert r == 1
+ cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'")
+ assert cur.fetchone()[0] == 1
+
+
+def test_outer_transaction_error(conn):
+ with conn.transaction():
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ with conn.pipeline():
+ conn.execute("select error")
+ conn.execute("create table voila ()")
+
+
+def test_rollback_explicit(conn):
+ conn.autocommit = True
+ with conn.pipeline():
+ with pytest.raises(e.DivisionByZero):
+ cur = conn.execute("select 1 / %s", [0])
+ cur.fetchone()
+ conn.rollback()
+ conn.execute("select 1")
+
+
+def test_rollback_transaction(conn):
+ conn.autocommit = True
+ with pytest.raises(e.DivisionByZero):
+ with conn.pipeline():
+ with conn.transaction():
+ cur = conn.execute("select 1 / %s", [0])
+ cur.fetchone()
+ conn.execute("select 1")
+
+
+def test_message_0x33(conn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ conn.autocommit = True
+ with conn.pipeline():
+ cur = conn.execute("select 'test'")
+ assert cur.fetchone() == ("test",)
+
+ assert not notices
+
+
+def test_transaction_state_implicit_begin(conn, trace):
+ # Regression test to ensure that the transaction state is correct after
+ # the implicit BEGIN statement (in non-autocommit mode).
+ notices = []
+ conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+ t = trace.trace(conn)
+ with conn.pipeline():
+ conn.execute("select 'x'").fetchone()
+ conn.execute("select 'y'")
+ assert not notices
+ assert [
+ e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0]
+ ] == [b' "" "BEGIN" 0']
+
+
+def test_concurrency(conn):
+ with conn.transaction():
+ conn.execute("drop table if exists pipeline_concurrency")
+ conn.execute("drop table if exists accessed")
+ with conn.transaction():
+ conn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ conn.execute("create unlogged table accessed as (select now() as value)")
+
+ def update(value):
+ cur = conn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning value",
+ (value,),
+ )
+ conn.execute("update accessed set value = now()")
+ return cur
+
+ conn.autocommit = True
+
+ (before,) = conn.execute("select value from accessed").fetchone()
+
+ values = range(1, 10)
+ with conn.pipeline():
+ with concurrent.futures.ThreadPoolExecutor() as e:
+ cursors = e.map(update, values, timeout=len(values))
+ assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+
+ (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+ assert s == sum(values)
+ (after,) = conn.execute("select value from accessed").fetchone()
+ assert after > before
diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py
new file mode 100644
index 0000000..2e743cf
--- /dev/null
+++ b/tests/test_pipeline_async.py
@@ -0,0 +1,586 @@
+import asyncio
+import logging
+from typing import Any
+from operator import attrgetter
+from itertools import groupby
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import errors as e
+
+from .test_pipeline import pipeline_aborted
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.pipeline,
+ pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"),
+]
+
+
+async def test_repr(aconn):
+ async with aconn.pipeline() as p:
+ assert "psycopg.AsyncPipeline" in repr(p)
+ assert "[IDLE, pipeline=ON]" in repr(p)
+
+ await aconn.close()
+ assert "[BAD]" in repr(p)
+
+
+async def test_connection_closed(aconn):
+ await aconn.close()
+ with pytest.raises(e.OperationalError):
+ async with aconn.pipeline():
+ pass
+
+
+async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None:
+ assert aconn._pipeline is None
+ async with aconn.pipeline() as p:
+ assert aconn._pipeline is p
+ assert p.status == pq.PipelineStatus.ON
+ assert p.status == pq.PipelineStatus.OFF
+ assert not aconn._pipeline
+
+
+async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None:
+ async with aconn.pipeline() as p1:
+ async with aconn.pipeline() as p2:
+ assert p2 is p1
+ assert p1.status == pq.PipelineStatus.ON
+ assert p2 is p1
+ assert p2.status == pq.PipelineStatus.ON
+ assert aconn._pipeline is None
+ assert p1.status == pq.PipelineStatus.OFF
+
+
+async def test_pipeline_broken_conn_exit(aconn: psycopg.AsyncConnection[Any]) -> None:
+ with pytest.raises(e.OperationalError):
+ async with aconn.pipeline():
+ await aconn.execute("select 1")
+ await aconn.close()
+ closed = True
+
+ assert closed
+
+
+async def test_pipeline_exit_error_noclobber(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.pipeline():
+ await aconn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 1
+
+
+async def test_pipeline_exit_error_noclobber_nested(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.pipeline():
+ async with aconn.pipeline():
+ await aconn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
+async def test_pipeline_exit_sync_trace(aconn, trace):
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ pass
+ await aconn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 1
+
+
+async def test_pipeline_nested_sync_trace(aconn, trace):
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ async with aconn.pipeline():
+ pass
+ await aconn.close()
+ assert len([i for i in t if i.type == "Sync"]) == 2
+
+
+async def test_cursor_stream(aconn):
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ with pytest.raises(psycopg.ProgrammingError):
+ await cur.stream("select 1").__anext__()
+
+
+async def test_server_cursor(aconn):
+ async with aconn.cursor(name="pipeline") as cur, aconn.pipeline():
+ with pytest.raises(psycopg.NotSupportedError):
+ await cur.execute("select 1")
+
+
+async def test_cannot_insert_multiple_commands(aconn):
+ with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)):
+ async with aconn.pipeline():
+ await aconn.execute("select 1; select 2")
+
+
+async def test_copy(aconn):
+ async with aconn.pipeline():
+ cur = aconn.cursor()
+ with pytest.raises(e.NotSupportedError):
+ async with cur.copy("copy (select 1) to stdout") as copy:
+ await copy.read()
+
+
+async def test_pipeline_processed_at_exit(aconn):
+ async with aconn.cursor() as cur:
+ async with aconn.pipeline() as p:
+ await cur.execute("select 1")
+
+ assert len(p.result_queue) == 1
+
+ assert await cur.fetchone() == (1,)
+
+
+async def test_pipeline_errors_processed_at_exit(aconn):
+ await aconn.set_autocommit(True)
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.pipeline():
+ await aconn.execute("select * from nosuchtable")
+ await aconn.execute("create table voila ()")
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = %s", ("voila",)
+ )
+ (count,) = await cur.fetchone()
+ assert count == 0
+
+
+async def test_pipeline(aconn):
+ async with aconn.pipeline() as p:
+ c1 = aconn.cursor()
+ c2 = aconn.cursor()
+ await c1.execute("select 1")
+ await c2.execute("select 2")
+
+ assert len(p.result_queue) == 2
+
+ (r1,) = await c1.fetchone()
+ assert r1 == 1
+
+ (r2,) = await c2.fetchone()
+ assert r2 == 2
+
+
+async def test_autocommit(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline(), aconn.cursor() as c:
+ await c.execute("select 1")
+
+ (r,) = await c.fetchone()
+ assert r == 1
+
+
+async def test_pipeline_aborted(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline() as p:
+ c1 = await aconn.execute("select 1")
+ with pytest.raises(e.UndefinedTable):
+ await (await aconn.execute("select * from doesnotexist")).fetchone()
+ with pytest.raises(e.PipelineAborted):
+ await (await aconn.execute("select 'aborted'")).fetchone()
+ # Sync restore the connection in usable state.
+ await p.sync()
+ c2 = await aconn.execute("select 2")
+
+ (r,) = await c1.fetchone()
+ assert r == 1
+
+ (r,) = await c2.fetchone()
+ assert r == 2
+
+
+async def test_pipeline_commit_aborted(aconn):
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+ await aconn.commit()
+
+
+async def test_sync_syncs_results(aconn):
+ async with aconn.pipeline() as p:
+ cur = await aconn.execute("select 1")
+ assert cur.statusmessage is None
+ await p.sync()
+ assert cur.statusmessage == "SELECT 1"
+
+
+async def test_sync_syncs_errors(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline() as p:
+ await aconn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ await p.sync()
+
+
+@pipeline_aborted
+async def test_errors_raised_on_commit(aconn):
+ async with aconn.pipeline():
+ await aconn.execute("select 1 from nosuchtable")
+ with pytest.raises(e.UndefinedTable):
+ await aconn.commit()
+ await aconn.rollback()
+ cur1 = await aconn.execute("select 1")
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+async def test_errors_raised_on_transaction_exit(aconn):
+ here = False
+ async with aconn.pipeline():
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.transaction():
+ await aconn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = await aconn.execute("select 1")
+ assert here
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+@pytest.mark.flakey("assert fails randomly in CI blocking release")
+async def test_errors_raised_on_nested_transaction_exit(aconn):
+ here = False
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ with pytest.raises(e.UndefinedTable):
+ async with aconn.transaction():
+ await aconn.execute("select 1 from nosuchtable")
+ here = True
+ cur1 = await aconn.execute("select 1")
+ assert here
+ cur2 = await aconn.execute("select 2")
+
+ assert await cur1.fetchone() == (1,)
+ assert await cur2.fetchone() == (2,)
+
+
+async def test_implicit_transaction(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+ await aconn.execute("select 'before'")
+ # Transaction is ACTIVE because previous command is not completed
+ # since we have not fetched its results.
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE
+ # Upon entering the nested pipeline through "with transaction():", a
+ # sync() is emitted to restore the transaction state to IDLE, as
+ # expected to emit a BEGIN.
+ async with aconn.transaction():
+ await aconn.execute("select 'tx'")
+ cur = await aconn.execute("select 'after'")
+ assert await cur.fetchone() == ("after",)
+
+
+@pytest.mark.crdb_skip("deferrable")
+async def test_error_on_commit(aconn):
+ await aconn.execute(
+ """
+ drop table if exists selfref;
+ create table selfref (
+ x serial primary key,
+ y int references selfref (x) deferrable initially deferred)
+ """
+ )
+ await aconn.commit()
+
+ async with aconn.pipeline():
+ await aconn.execute("insert into selfref (y) values (-1)")
+ with pytest.raises(e.ForeignKeyViolation):
+ await aconn.commit()
+ cur1 = await aconn.execute("select 1")
+ cur2 = await aconn.execute("select 2")
+
+ assert (await cur1.fetchone()) == (1,)
+ assert (await cur2.fetchone()) == (2,)
+
+
+async def test_fetch_no_result(aconn):
+ async with aconn.pipeline():
+ cur = aconn.cursor()
+ with pytest.raises(e.ProgrammingError):
+ await cur.fetchone()
+
+
+async def test_executemany(aconn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("drop table if exists execmanypipeline")
+ await aconn.execute(
+ "create unlogged table execmanypipeline ("
+ " id serial primary key, num integer)"
+ )
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ await cur.executemany(
+ "insert into execmanypipeline(num) values (%s) returning num",
+ [(10,), (20,)],
+ returning=True,
+ )
+ assert cur.rowcount == 2
+ assert (await cur.fetchone()) == (10,)
+ assert cur.nextset()
+ assert (await cur.fetchone()) == (20,)
+ assert cur.nextset() is None
+
+
+async def test_executemany_no_returning(aconn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("drop table if exists execmanypipelinenoreturning")
+ await aconn.execute(
+ "create unlogged table execmanypipelinenoreturning ("
+ " id serial primary key, num integer)"
+ )
+ async with aconn.pipeline(), aconn.cursor() as cur:
+ await cur.executemany(
+ "insert into execmanypipelinenoreturning(num) values (%s)",
+ [(10,), (20,)],
+ returning=False,
+ )
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ await cur.fetchone()
+ assert cur.nextset() is None
+ with pytest.raises(e.ProgrammingError, match="no result available"):
+ await cur.fetchone()
+ assert cur.nextset() is None
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+async def test_executemany_trace(aconn, trace):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("create temp table trace (id int)")
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)])
+ await aconn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"]
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+@pytest.mark.crdb("skip", reason="temp tables")
+async def test_executemany_trace_returning(aconn, trace):
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("create temp table trace (id int)")
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ await cur.executemany(
+ "insert into trace (id) values (%s)", [(10,), (20,)], returning=True
+ )
+ await aconn.close()
+ items = list(t)
+ assert items[-1].type == "Terminate"
+ del items[-1]
+ roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ assert roundtrips == ["F", "B"] * 3
+ assert items[-2].direction == "F" # last 2 items are F B
+ assert len([i for i in items if i.type == "Sync"]) == 1
+
+
+async def test_prepared(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ c1 = await aconn.execute("select %s::int", [10], prepare=True)
+ c2 = await aconn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+
+ (r,) = await c1.fetchone()
+ assert r == 10
+
+ (r,) = await c2.fetchone()
+ assert r == 1
+
+
+async def test_auto_prepare(aconn):
+ aconn.prepared_threshold = 5
+ async with aconn.pipeline():
+ cursors = [
+ await aconn.execute(
+ "select count(*) from pg_prepared_statements where name != ''"
+ )
+ for i in range(10)
+ ]
+
+ assert len(aconn._prepared._names) == 1
+
+ res = [(await c.fetchone())[0] for c in cursors]
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_transaction(aconn):
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'tx'")
+
+ (r,) = await cur.fetchone()
+ assert r == "tx"
+
+ async with aconn.transaction():
+ cur = await aconn.execute("select 'rb'")
+ raise psycopg.Rollback()
+
+ (r,) = await cur.fetchone()
+ assert r == "rb"
+
+ assert not notices
+
+
+async def test_transaction_nested(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ outer = await aconn.execute("select 'outer'")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.transaction():
+ inner = await aconn.execute("select 'inner'")
+ 1 / 0
+
+ (r,) = await outer.fetchone()
+ assert r == "outer"
+ (r,) = await inner.fetchone()
+ assert r == "inner"
+
+
+async def test_transaction_nested_no_statement(aconn):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 1")
+
+ (r,) = await cur.fetchone()
+ assert r == 1
+
+
+async def test_outer_transaction(aconn):
+ async with aconn.transaction():
+ await aconn.execute("drop table if exists outertx")
+ async with aconn.transaction():
+ async with aconn.pipeline():
+ await aconn.execute("create table outertx as (select 1)")
+ cur = await aconn.execute("select * from outertx")
+ (r,) = await cur.fetchone()
+ assert r == 1
+ cur = await aconn.execute(
+ "select count(*) from pg_tables where tablename = 'outertx'"
+ )
+ assert (await cur.fetchone())[0] == 1
+
+
+async def test_outer_transaction_error(aconn):
+ async with aconn.transaction():
+ with pytest.raises((e.UndefinedColumn, e.OperationalError)):
+ async with aconn.pipeline():
+ await aconn.execute("select error")
+ await aconn.execute("create table voila ()")
+
+
+async def test_rollback_explicit(aconn):
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ with pytest.raises(e.DivisionByZero):
+ cur = await aconn.execute("select 1 / %s", [0])
+ await cur.fetchone()
+ await aconn.rollback()
+ await aconn.execute("select 1")
+
+
+async def test_rollback_transaction(aconn):
+ await aconn.set_autocommit(True)
+ with pytest.raises(e.DivisionByZero):
+ async with aconn.pipeline():
+ async with aconn.transaction():
+ cur = await aconn.execute("select 1 / %s", [0])
+ await cur.fetchone()
+ await aconn.execute("select 1")
+
+
+async def test_message_0x33(aconn):
+ # https://github.com/psycopg/psycopg/issues/314
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
+ await aconn.set_autocommit(True)
+ async with aconn.pipeline():
+ cur = await aconn.execute("select 'test'")
+ assert (await cur.fetchone()) == ("test",)
+
+ assert not notices
+
+
+async def test_transaction_state_implicit_begin(aconn, trace):
+ # Regression test to ensure that the transaction state is correct after
+ # the implicit BEGIN statement (in non-autocommit mode).
+ notices = []
+ aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+ t = trace.trace(aconn)
+ async with aconn.pipeline():
+ await (await aconn.execute("select 'x'")).fetchone()
+ await aconn.execute("select 'y'")
+ assert not notices
+ assert [
+ e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0]
+ ] == [b' "" "BEGIN" 0']
+
+
+async def test_concurrency(aconn):
+ async with aconn.transaction():
+ await aconn.execute("drop table if exists pipeline_concurrency")
+ await aconn.execute("drop table if exists accessed")
+ async with aconn.transaction():
+ await aconn.execute(
+ "create unlogged table pipeline_concurrency ("
+ " id serial primary key,"
+ " value integer"
+ ")"
+ )
+ await aconn.execute("create unlogged table accessed as (select now() as value)")
+
+ async def update(value):
+ cur = await aconn.execute(
+ "insert into pipeline_concurrency(value) values (%s) returning value",
+ (value,),
+ )
+ await aconn.execute("update accessed set value = now()")
+ return cur
+
+ await aconn.set_autocommit(True)
+
+ (before,) = await (await aconn.execute("select value from accessed")).fetchone()
+
+ values = range(1, 10)
+ async with aconn.pipeline():
+ cursors = await asyncio.wait_for(
+ asyncio.gather(*[update(value) for value in values]),
+ timeout=len(values),
+ )
+
+ assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+
+ (s,) = await (
+ await aconn.execute("select sum(value) from pipeline_concurrency")
+ ).fetchone()
+ assert s == sum(values)
+ (after,) = await (await aconn.execute("select value from accessed")).fetchone()
+ assert after > before
diff --git a/tests/test_prepared.py b/tests/test_prepared.py
new file mode 100644
index 0000000..56c580a
--- /dev/null
+++ b/tests/test_prepared.py
@@ -0,0 +1,277 @@
+"""
+Prepared statements tests
+"""
+
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg.rows import namedtuple_row
+
+
+@pytest.mark.parametrize("value", [None, 0, 3])
+def test_prepare_threshold_init(conn_cls, dsn, value):
+ with conn_cls.connect(dsn, prepare_threshold=value) as conn:
+ assert conn.prepare_threshold == value
+
+
+def test_dont_prepare(conn):
+ cur = conn.cursor()
+ for i in range(10):
+ cur.execute("select %s::int", [i], prepare=False)
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_do_prepare(conn):
+ cur = conn.cursor()
+ cur.execute("select %s::int", [10], prepare=True)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_auto_prepare(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select %s::int", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_dont_prepare_conn(conn):
+ for i in range(10):
+ conn.execute("select %s::int", [i], prepare=False)
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_do_prepare_conn(conn):
+ conn.execute("select %s::int", [10], prepare=True)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_auto_prepare_conn(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select %s", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+def test_prepare_disable(conn):
+ conn.prepare_threshold = None
+ res = []
+ for i in range(10):
+ conn.execute("select %s", [0])
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+ assert not conn._prepared._names
+ assert not conn._prepared._counts
+
+
+def test_no_prepare_multi(conn):
+ res = []
+ for i in range(10):
+ conn.execute("select 1; select 2")
+ stmts = get_prepared_statements(conn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+
+
+def test_no_prepare_multi_with_drop(conn):
+ conn.execute("select 1", prepare=True)
+
+ for i in range(10):
+ conn.execute("drop table if exists noprep; create table noprep()")
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+def test_no_prepare_error(conn):
+ conn.autocommit = True
+ for i in range(10):
+ with pytest.raises(conn.ProgrammingError):
+ conn.execute("select wat")
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_no_prepare ()",
+ pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
+ "set timezone = utc",
+ "select num from prepared_test",
+ "insert into prepared_test (num) values (1)",
+ "update prepared_test set num = num * 2",
+ "delete from prepared_test where num > 10",
+ ],
+)
+def test_misc_statement(conn, query):
+ conn.execute("create table prepared_test (num int)", prepare=False)
+ conn.prepare_threshold = 0
+ conn.execute(query)
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+
+
+def test_params_types(conn):
+ conn.execute(
+ "select %s, %s, %s",
+ [dt.date(2020, 12, 10), 42, Decimal(42)],
+ prepare=True,
+ )
+ stmts = get_prepared_statements(conn)
+ want = [stmt.parameter_types for stmt in stmts]
+ assert want == [["date", "smallint", "numeric"]]
+
+
+def test_evict_lru(conn):
+ conn.prepared_max = 5
+ for i in range(10):
+ conn.execute("select 'a'")
+ conn.execute(f"select {i}")
+
+ assert len(conn._prepared._names) == 1
+ assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
+ for i in [9, 8, 7, 6]:
+ assert conn._prepared._counts[f"select {i}".encode(), ()] == 1
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 1
+ assert stmts[0].statement == "select 'a'"
+
+
+def test_evict_lru_deallocate(conn):
+ conn.prepared_max = 5
+ conn.prepare_threshold = 0
+ for i in range(10):
+ conn.execute("select 'a'")
+ conn.execute(f"select {i}")
+
+ assert len(conn._prepared._names) == 5
+ for j in [9, 8, 7, 6, "'a'"]:
+ name = conn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
+
+ stmts = get_prepared_statements(conn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.statement for stmt in stmts]
+ assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
+
+
+def test_different_types(conn):
+ conn.prepare_threshold = 0
+ conn.execute("select %s", [None])
+ conn.execute("select %s", [dt.date(2000, 1, 1)])
+ conn.execute("select %s", [42])
+ conn.execute("select %s", [41])
+ conn.execute("select %s", [dt.date(2000, 1, 2)])
+
+ stmts = get_prepared_statements(conn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["text"], ["date"], ["smallint"]]
+
+
+def test_untyped_json(conn):
+ conn.prepare_threshold = 1
+ conn.execute("create table testjson(data jsonb)")
+
+ for i in range(2):
+ conn.execute("insert into testjson (data) values (%s)", ["{}"])
+
+ stmts = get_prepared_statements(conn)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["jsonb"]]
+
+
+def test_change_type_execute(conn):
+ conn.prepare_threshold = 0
+ for i in range(3):
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ conn.rollback()
+
+
+def test_change_type_executemany(conn):
+ for i in range(3):
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().executemany(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+ )
+ conn.rollback()
+
+
+@pytest.mark.crdb("skip", reason="can't re-create a type")
+def test_change_type(conn):
+ conn.prepare_threshold = 0
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ conn.execute("DROP TABLE preptable")
+ conn.execute("DROP TYPE prepenum")
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+
+ stmts = get_prepared_statements(conn)
+ assert len(stmts) == 3
+
+
+def test_change_type_savepoint(conn):
+ conn.prepare_threshold = 0
+ with conn.transaction():
+ for i in range(3):
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
+ conn.cursor().execute(
+ "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])",
+ {"enum_col": ["foo"]},
+ )
+ raise ZeroDivisionError()
+
+
+def get_prepared_statements(conn):
+ cur = conn.cursor(row_factory=namedtuple_row)
+ cur.execute(
+ # CRDB has 'PREPARE name AS' in the statement.
+ r"""
+select name,
+ regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+ prepare_time,
+ parameter_types
+from pg_prepared_statements
+where name != ''
+ """,
+ prepare=False,
+ )
+ return cur.fetchall()
diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py
new file mode 100644
index 0000000..84d948f
--- /dev/null
+++ b/tests/test_prepared_async.py
@@ -0,0 +1,207 @@
+"""
+Prepared statements tests on async connections
+"""
+
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.parametrize("value", [None, 0, 3])
+async def test_prepare_threshold_init(aconn_cls, dsn, value):
+ async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn:
+ assert conn.prepare_threshold == value
+
+
+async def test_dont_prepare(aconn):
+ cur = aconn.cursor()
+ for i in range(10):
+ await cur.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s::int", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_dont_prepare_conn(aconn):
+ for i in range(10):
+ await aconn.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare_conn(aconn):
+ await aconn.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare_conn(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_prepare_disable(aconn):
+ aconn.prepare_threshold = None
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+ assert not aconn._prepared._names
+ assert not aconn._prepared._counts
+
+
+async def test_no_prepare_multi(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select 1; select 2")
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+
+
+async def test_no_prepare_error(aconn):
+ await aconn.set_autocommit(True)
+ for i in range(10):
+ with pytest.raises(aconn.ProgrammingError):
+ await aconn.execute("select wat")
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_no_prepare ()",
+ pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
+ "set timezone = utc",
+ "select num from prepared_test",
+ "insert into prepared_test (num) values (1)",
+ "update prepared_test set num = num * 2",
+ "delete from prepared_test where num > 10",
+ ],
+)
+async def test_misc_statement(aconn, query):
+ await aconn.execute("create table prepared_test (num int)", prepare=False)
+ aconn.prepare_threshold = 0
+ await aconn.execute(query)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_params_types(aconn):
+ await aconn.execute(
+ "select %s, %s, %s",
+ [dt.date(2020, 12, 10), 42, Decimal(42)],
+ prepare=True,
+ )
+ stmts = await get_prepared_statements(aconn)
+ want = [stmt.parameter_types for stmt in stmts]
+ assert want == [["date", "smallint", "numeric"]]
+
+
+async def test_evict_lru(aconn):
+ aconn.prepared_max = 5
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 1
+ assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
+ for i in [9, 8, 7, 6]:
+ assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+ assert stmts[0].statement == "select 'a'"
+
+
+async def test_evict_lru_deallocate(aconn):
+ aconn.prepared_max = 5
+ aconn.prepare_threshold = 0
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 5
+ for j in [9, 8, 7, 6, "'a'"]:
+ name = aconn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.statement for stmt in stmts]
+ assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
+
+
+async def test_different_types(aconn):
+ aconn.prepare_threshold = 0
+ await aconn.execute("select %s", [None])
+ await aconn.execute("select %s", [dt.date(2000, 1, 1)])
+ await aconn.execute("select %s", [42])
+ await aconn.execute("select %s", [41])
+ await aconn.execute("select %s", [dt.date(2000, 1, 2)])
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["text"], ["date"], ["smallint"]]
+
+
+async def test_untyped_json(aconn):
+ aconn.prepare_threshold = 1
+ await aconn.execute("create table testjson(data jsonb)")
+ for i in range(2):
+ await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
+
+ stmts = await get_prepared_statements(aconn)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["jsonb"]]
+
+
+async def get_prepared_statements(aconn):
+ cur = aconn.cursor(row_factory=namedtuple_row)
+ await cur.execute(
+ r"""
+select name,
+ regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+ prepare_time,
+ parameter_types
+from pg_prepared_statements
+where name != ''
+ """,
+ prepare=False,
+ )
+ return await cur.fetchall()
diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py
new file mode 100644
index 0000000..82a5d73
--- /dev/null
+++ b/tests/test_psycopg_dbapi20.py
@@ -0,0 +1,164 @@
+import pytest
+import datetime as dt
+from typing import Any, Dict
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from . import dbapi20
+from . import dbapi20_tpc
+
+
+@pytest.fixture(scope="class")
+def with_dsn(request, session_dsn):
+ request.cls.connect_args = (session_dsn,)
+
+
+@pytest.mark.usefixtures("with_dsn")
+class PsycopgTests(dbapi20.DatabaseAPI20Test):
+ driver = psycopg
+ # connect_args = () # set by the fixture
+ connect_kw_args: Dict[str, Any] = {}
+
+ def test_nextset(self):
+ # tested elsewhere
+ pass
+
+ def test_setoutputsize(self):
+ # no-op
+ pass
+
+
+@pytest.mark.usefixtures("tpc")
+@pytest.mark.usefixtures("with_dsn")
+class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests):
+ driver = psycopg
+ connect_args = () # set by the fixture
+
+ def connect(self):
+ return psycopg.connect(*self.connect_args)
+
+
+# Shut up warnings
+PsycopgTests.failUnless = PsycopgTests.assertTrue
+PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual
+
+
+@pytest.mark.parametrize(
+ "typename, singleton",
+ [
+ ("bytea", "BINARY"),
+ ("date", "DATETIME"),
+ ("timestamp without time zone", "DATETIME"),
+ ("timestamp with time zone", "DATETIME"),
+ ("time without time zone", "DATETIME"),
+ ("time with time zone", "DATETIME"),
+ ("interval", "DATETIME"),
+ ("integer", "NUMBER"),
+ ("smallint", "NUMBER"),
+ ("bigint", "NUMBER"),
+ ("real", "NUMBER"),
+ ("double precision", "NUMBER"),
+ ("numeric", "NUMBER"),
+ ("decimal", "NUMBER"),
+ ("oid", "ROWID"),
+ ("varchar", "STRING"),
+ ("char", "STRING"),
+ ("text", "STRING"),
+ ],
+)
+def test_singletons(conn, typename, singleton):
+ singleton = getattr(psycopg, singleton)
+ cur = conn.cursor()
+ cur.execute(f"select null::{typename}")
+ oid = cur.description[0].type_code
+ assert singleton == oid
+ assert oid == singleton
+ assert singleton != oid + 10000
+ assert oid + 10000 != singleton
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [
+ (0, "1970-01-01T00:00:00.000000+0000"),
+ (1273173119.99992, "2010-05-06T14:11:59.999920-0500"),
+ ],
+)
+def test_timestamp_from_ticks(ticks, want):
+ s = psycopg.TimestampFromTicks(ticks)
+ want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z")
+ assert s == want
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [
+ (0, "1970-01-01"),
+ # Returned date is local
+ (1273173119.99992, ["2010-05-06", "2010-05-07"]),
+ ],
+)
+def test_date_from_ticks(ticks, want):
+ s = psycopg.DateFromTicks(ticks)
+ if isinstance(want, str):
+ want = [want]
+ want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want]
+ assert s in want
+
+
+@pytest.mark.parametrize(
+ "ticks, want",
+ [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")],
+)
+def test_time_from_ticks(ticks, want):
+ s = psycopg.TimeFromTicks(ticks)
+ want = dt.datetime.strptime(want, "%H:%M:%S.%f").time()
+ assert s.replace(hour=0) == want
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, want",
+ [
+ ((), {}, ""),
+ (("",), {}, ""),
+ (("host=foo user=bar",), {}, "host=foo user=bar"),
+ (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (
+ ("host=foo port=5432",),
+ {"host": "qux", "user": "joe"},
+ "host=qux user=joe port=5432",
+ ),
+ (("host=foo",), {"user": None}, "host=foo"),
+ ],
+)
+def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
+ the_conninfo: str
+
+ def fake_connect(conninfo):
+ nonlocal the_conninfo
+ the_conninfo = conninfo
+ return pgconn
+ yield
+
+ monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+ conn = psycopg.connect(*args, **kwargs)
+ assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+ conn.close()
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, exctype",
+ [
+ (("host=foo", "host=bar"), {}, TypeError),
+ (("", ""), {}, TypeError),
+ ((), {"nosuchparam": 42}, psycopg.ProgrammingError),
+ ],
+)
+def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype):
+ def fake_connect(conninfo):
+ return pgconn
+ yield
+
+ with pytest.raises(exctype):
+ psycopg.connect(*args, **kwargs)
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 0000000..7263a80
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,162 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg.adapt import Transformer, PyFormat
+from psycopg._queries import PostgresQuery, _split_query
+
+
+@pytest.mark.parametrize(
+ "input, want",
+ [
+ (b"", [(b"", 0, PyFormat.AUTO)]),
+ (b"foo bar", [(b"foo bar", 0, PyFormat.AUTO)]),
+ (b"foo %% bar", [(b"foo % bar", 0, PyFormat.AUTO)]),
+ (b"%s", [(b"", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]),
+ (b"%s foo", [(b"", 0, PyFormat.AUTO), (b" foo", 0, PyFormat.AUTO)]),
+ (b"%b foo", [(b"", 0, PyFormat.BINARY), (b" foo", 0, PyFormat.AUTO)]),
+ (b"foo %s", [(b"foo ", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]),
+ (
+ b"foo %%%s bar",
+ [(b"foo %", 0, PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)],
+ ),
+ (
+ b"foo %(name)s bar",
+ [(b"foo ", "name", PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)],
+ ),
+ (
+ b"foo %(name)s %(name)b bar",
+ [
+ (b"foo ", "name", PyFormat.AUTO),
+ (b" ", "name", PyFormat.BINARY),
+ (b" bar", 0, PyFormat.AUTO),
+ ],
+ ),
+ (
+ b"foo %s%b bar %s baz",
+ [
+ (b"foo ", 0, PyFormat.AUTO),
+ (b"", 1, PyFormat.BINARY),
+ (b" bar ", 2, PyFormat.AUTO),
+ (b" baz", 0, PyFormat.AUTO),
+ ],
+ ),
+ ],
+)
+def test_split_query(input, want):
+ assert _split_query(input) == want
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ b"foo %d bar",
+ b"foo % bar",
+ b"foo %%% bar",
+ b"foo %(foo)d bar",
+ b"foo %(foo)s bar %s baz",
+ b"foo %(foo) bar",
+ b"foo %(foo bar",
+ b"3%2",
+ ],
+)
+def test_split_query_bad(input):
+ with pytest.raises(psycopg.ProgrammingError):
+ _split_query(input)
+
+
+@pytest.mark.parametrize(
+ "query, params, want, wformats, wparams",
+ [
+ (b"", None, b"", None, None),
+ (b"", [], b"", [], []),
+ (b"%%", [], b"%", [], []),
+ (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),
+ (
+ b"%t %% %t",
+ (1, 2),
+ b"$1 % $2",
+ [pq.Format.TEXT, pq.Format.TEXT],
+ [b"1", b"2"],
+ ),
+ (
+ b"%t %% %t",
+ ("a", 2),
+ b"$1 % $2",
+ [pq.Format.TEXT, pq.Format.TEXT],
+ [b"a", b"2"],
+ ),
+ ],
+)
+def test_pg_query_seq(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
+
+
+@pytest.mark.parametrize(
+ "query, params, want, wformats, wparams",
+ [
+ (b"", {}, b"", [], []),
+ (b"hello %%", {"a": 1}, b"hello %", [], []),
+ (
+ b"select %(hello)t",
+ {"hello": 1, "world": 2},
+ b"select $1",
+ [pq.Format.TEXT],
+ [b"1"],
+ ),
+ (
+ b"select %(hi)s %(there)s %(hi)s",
+ {"hi": 0, "there": "a"},
+ b"select $1 $2 $1",
+ [pq.Format.BINARY, pq.Format.TEXT],
+ [b"\x00" * 2, b"a"],
+ ),
+ ],
+)
+def test_pg_query_map(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"select %s", {"a": 1}),
+ (b"select %(name)s", [1]),
+ (b"select %s", "a"),
+ (b"select %s", 1),
+ (b"select %s", b"a"),
+ (b"select %s", set()),
+ ],
+)
+def test_pq_query_badtype(query, params):
+ pq = PostgresQuery(Transformer())
+ with pytest.raises(TypeError):
+ pq.convert(query, params)
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"", [1]),
+ (b"%s", []),
+ (b"%%", [1]),
+ (b"$1", [1]),
+ (b"select %(", {"a": 1}),
+ (b"select %(a", {"a": 1}),
+ (b"select %(a)", {"a": 1}),
+ (b"select %s %(hi)s", [1]),
+ (b"select %(hi)s %(hi)b", {"hi": 1}),
+ ],
+)
+def test_pq_query_badprog(query, params):
+ pq = PostgresQuery(Transformer())
+ with pytest.raises(psycopg.ProgrammingError):
+ pq.convert(query, params)
diff --git a/tests/test_rows.py b/tests/test_rows.py
new file mode 100644
index 0000000..5165b80
--- /dev/null
+++ b/tests/test_rows.py
@@ -0,0 +1,167 @@
+import pytest
+
+import psycopg
+from psycopg import rows
+
+from .utils import eur
+
+
+def test_tuple_row(conn):
+ conn.row_factory = rows.dict_row
+ assert conn.execute("select 1 as a").fetchone() == {"a": 1}
+ cur = conn.cursor(row_factory=rows.tuple_row)
+ row = cur.execute("select 1 as a").fetchone()
+ assert row == (1,)
+ assert type(row) is tuple
+ assert cur._make_row is tuple
+
+
+def test_dict_row(conn):
+ cur = conn.cursor(row_factory=rows.dict_row)
+ cur.execute("select 'bob' as name, 3 as id")
+ assert cur.fetchall() == [{"name": "bob", "id": 3}]
+
+ cur.execute("select 'a' as letter; select 1 as number")
+ assert cur.fetchall() == [{"letter": "a"}]
+ assert cur.nextset()
+ assert cur.fetchall() == [{"number": 1}]
+ assert not cur.nextset()
+
+
+def test_namedtuple_row(conn):
+ rows._make_nt.cache_clear()
+ cur = conn.cursor(row_factory=rows.namedtuple_row)
+ cur.execute("select 'bob' as name, 3 as id")
+ (person1,) = cur.fetchall()
+ assert f"{person1.name} {person1.id}" == "bob 3"
+
+ ci1 = rows._make_nt.cache_info()
+ assert ci1.hits == 0 and ci1.misses == 1
+
+ cur.execute("select 'alice' as name, 1 as id")
+ (person2,) = cur.fetchall()
+ assert type(person2) is type(person1)
+
+ ci2 = rows._make_nt.cache_info()
+ assert ci2.hits == 1 and ci2.misses == 1
+
+ cur.execute("select 'foo', 1 as id")
+ (r0,) = cur.fetchall()
+ assert r0.f_column_ == "foo"
+ assert r0.id == 1
+
+ cur.execute("select 'a' as letter; select 1 as number")
+ (r1,) = cur.fetchall()
+ assert r1.letter == "a"
+ assert cur.nextset()
+ (r2,) = cur.fetchall()
+ assert r2.number == 1
+ assert not cur.nextset()
+ assert type(r1) is not type(r2)
+
+ cur.execute(f'select 1 as üåäö, 2 as _, 3 as "123", 4 as "a-b", 5 as "{eur}eur"')
+ (r3,) = cur.fetchall()
+ assert r3.üåäö == 1
+ assert r3.f_ == 2
+ assert r3.f123 == 3
+ assert r3.a_b == 4
+ assert r3.f_eur == 5
+
+
+def test_class_row(conn):
+ cur = conn.cursor(row_factory=rows.class_row(Person))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ (p,) = cur.fetchall()
+ assert isinstance(p, Person)
+ assert p.first == "John"
+ assert p.last == "Doe"
+ assert p.age is None
+
+ for query in (
+ "select 'John' as first",
+ "select 'John' as first, 'Doe' as last, 42 as wat",
+ ):
+ cur.execute(query)
+ with pytest.raises(TypeError):
+ cur.fetchone()
+
+
+def test_args_row(conn):
+ cur = conn.cursor(row_factory=rows.args_row(argf))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ assert cur.fetchone() == "JohnDoe"
+
+
+def test_kwargs_row(conn):
+ cur = conn.cursor(row_factory=rows.kwargs_row(kwargf))
+ cur.execute("select 'John' as first, 'Doe' as last")
+ (p,) = cur.fetchall()
+ assert isinstance(p, Person)
+ assert p.first == "John"
+ assert p.last == "Doe"
+ assert p.age == 42
+
+
+@pytest.mark.parametrize(
+ "factory",
+ "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
+)
+def test_no_result(factory, conn):
+ cur = conn.cursor(row_factory=factory_from_name(factory))
+ cur.execute("reset search_path")
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.fetchone()
+
+
+@pytest.mark.crdb_skip("no col query")
+@pytest.mark.parametrize(
+ "factory", "tuple_row dict_row namedtuple_row args_row".split()
+)
+def test_no_column(factory, conn):
+ cur = conn.cursor(row_factory=factory_from_name(factory))
+ cur.execute("select")
+ recs = cur.fetchall()
+ assert len(recs) == 1
+ assert not recs[0]
+
+
+@pytest.mark.crdb("skip")
+def test_no_column_class_row(conn):
+ class Empty:
+ def __init__(self, x=10, y=20):
+ self.x = x
+ self.y = y
+
+ cur = conn.cursor(row_factory=rows.class_row(Empty))
+ cur.execute("select")
+ x = cur.fetchone()
+ assert isinstance(x, Empty)
+ assert x.x == 10
+ assert x.y == 20
+
+
+def factory_from_name(name):
+ factory = getattr(rows, name)
+ if factory is rows.class_row:
+ factory = factory(Person)
+ if factory is rows.args_row:
+ factory = factory(argf)
+ if factory is rows.kwargs_row:
+ factory = factory(argf)
+
+ return factory
+
+
+class Person:
+ def __init__(self, first, last, age=None):
+ self.first = first
+ self.last = last
+ self.age = age
+
+
+def argf(*args):
+ return "".join(map(str, args))
+
+
+def kwargf(**kwargs):
+ return Person(**kwargs, age=42)
diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py
new file mode 100644
index 0000000..f7b6c8e
--- /dev/null
+++ b/tests/test_server_cursor.py
@@ -0,0 +1,525 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = pytest.mark.crdb_skip("server-side cursor")
+
+
+def test_init_row_factory(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is conn
+ assert cur.row_factory is conn.row_factory
+
+ conn.row_factory = rows.dict_row
+
+ with psycopg.ServerCursor(conn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+def test_init_params(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+def test_funny_name(conn):
+ cur = conn.cursor("1-2-3")
+ cur.execute("select generate_series(1, 3) as bar")
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ cur.close()
+
+
+def test_repr(conn):
+ cur = conn.cursor("my-name")
+ assert "psycopg.ServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ cur.close()
+
+
+def test_connection(conn):
+ cur = conn.cursor("foo")
+ assert cur.connection is conn
+ cur.close()
+
+
+def test_description(conn):
+ cur = conn.cursor("foo")
+ assert cur.name == "foo"
+ cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ cur.close()
+
+
+def test_format(conn):
+ cur = conn.cursor("foo")
+ assert cur.format == Format.TEXT
+ cur.close()
+
+ cur = conn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ cur.close()
+
+
+def test_query_params(conn):
+ with conn.cursor("foo") as cur:
+ assert cur._query is None
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+def test_binary_cursor_execute(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ cur.close()
+
+
+def test_execute_binary(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ cur.execute("select generate_series(1, 1)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ cur.close()
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)", binary=False)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ cur.close()
+
+
+def test_close(conn, recwarn):
+ if conn.info.transaction_status == conn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ conn.execute("close foo")
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ cur.close()
+ assert cur.closed
+
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_idempotent(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ cur.fetchall()
+ cur.close()
+ cur.close()
+
+
+def test_close_broken_conn(conn):
+ cur = conn.cursor("foo")
+ conn.close()
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchall()
+
+
+def test_close_noop(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_on_error(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+ cur.close()
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_context(conn, recwarn):
+ recwarn.clear()
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_no_clobber(conn):
+ with pytest.raises(e.DivisionByZero):
+ with conn.cursor("foo") as cur:
+ cur.execute("select 1 / %s", (0,))
+ cur.fetchall()
+
+
+def test_warn_close(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+def test_execute_reuse(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert cur.fetchone() == (1,)
+
+ cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+def test_execute_error(conn, stmt):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ cur.execute(stmt)
+ cur.close()
+
+
+def test_executemany(conn):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ cur.executemany("select %s", [(1,), (2,)])
+ cur.close()
+
+
+def test_fetchone(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchone() == (2,)
+ assert cur.fetchone() is None
+
+
+def test_fetchmany(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert cur.fetchone() == (4,)
+ assert cur.fetchmany(3) == [(5,)]
+ assert cur.fetchmany(3) == []
+
+
+def test_fetchall(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.fetchall() == []
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchall() == [(2,), (3,)]
+ assert cur.fetchall() == []
+
+
+def test_nextset(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+def test_no_result(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert cur.fetchall() == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_standard_row_factory(conn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ with conn.cursor("foo", row_factory=row_factory) as cur:
+ cur.execute("select generate_series(1, 5) as bar")
+ assert getter(cur.fetchone()) == 1
+ assert list(map(getter, cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_row_factory(conn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ cur.execute("select generate_series(1, 3) as x")
+ recs = cur.fetchall()
+ cur.scroll(0, "absolute")
+ while True:
+ rec = cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"x": 1}
+ cur.close()
+
+
+def test_rownumber(conn):
+ cur = conn.cursor("foo")
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ cur.fetchall()
+ assert cur.rownumber == 42
+ cur.close()
+
+
+def test_iter(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = list(cur)
+ assert recs == [(1,), (2,), (3,)]
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ recs = list(cur)
+ assert recs == [(2,), (3,)]
+
+
+def test_iter_rownumber(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ for row in cur:
+ assert cur.rownumber == row[0]
+
+
+def test_itersize(conn, commands):
+ with conn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ commands.popall() # flush begin and other noise
+
+ list(cur)
+ cmds = commands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+def test_cant_scroll_by_default(conn):
+ cur = conn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ cur.scroll(0)
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scroll(conn):
+ cur = conn.cursor("tmp", scrollable=True)
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ cur.scroll(9, mode="wat")
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ for i in range(4, -1, -1):
+ curs.scroll(-1)
+ assert i == curs.fetchone()[0]
+ curs.scroll(-1)
+ curs.close()
+
+
+def test_non_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ curs.scroll(-1)
+ curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+def test_no_hold(conn, kwargs):
+ with conn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ curs.execute("select generate_series(0, 2)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+def test_hold(conn):
+ with conn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ curs.execute("select generate_series(0, 5)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ assert curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+def test_steal_cursor(conn, row_factory):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6) as s")
+
+ cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert cur2.fetchall() == [(5,), (6,)]
+ cur2.close()
+
+
+def test_stolen_cursor_close(conn):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()
+
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()
diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py
new file mode 100644
index 0000000..21b4345
--- /dev/null
+++ b/tests/test_server_cursor_async.py
@@ -0,0 +1,543 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("server-side cursor"),
+]
+
+
+async def test_init_row_factory(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is aconn
+ assert cur.row_factory is aconn.row_factory
+
+ aconn.row_factory = rows.dict_row
+
+ async with psycopg.AsyncServerCursor(aconn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "baz", row_factory=rows.namedtuple_row
+ ) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+async def test_init_params(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "bar", withhold=True, scrollable=False
+ ) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+async def test_funny_name(aconn):
+ cur = aconn.cursor("1-2-3")
+ await cur.execute("select generate_series(1, 3) as bar")
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ await cur.close()
+
+
+async def test_repr(aconn):
+ cur = aconn.cursor("my-name")
+ assert "psycopg.AsyncServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ await cur.close()
+
+
+async def test_connection(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.connection is aconn
+ await cur.close()
+
+
+async def test_description(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.name == "foo"
+ await cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ await cur.close()
+
+
+async def test_format(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.format == Format.TEXT
+ await cur.close()
+
+ cur = aconn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ await cur.close()
+
+
+async def test_query_params(aconn):
+ async with aconn.cursor("foo") as cur:
+ assert cur._query is None
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query is not None
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+async def test_binary_cursor_execute(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ await cur.close()
+
+
+async def test_execute_binary(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ await cur.execute("select generate_series(1, 1)")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ await cur.close()
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)", binary=False)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ await cur.close()
+
+
+async def test_close(aconn, recwarn):
+ if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ await aconn.execute("close foo")
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ await cur.close()
+ assert cur.closed
+
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_idempotent(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ await cur.fetchall()
+ await cur.close()
+ await cur.close()
+
+
+async def test_close_broken_conn(aconn):
+ cur = aconn.cursor("foo")
+ await aconn.close()
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_close_noop(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_on_error(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ await aconn.execute("wat")
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+ await cur.close()
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_context(aconn, recwarn):
+ recwarn.clear()
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_no_clobber(aconn):
+ with pytest.raises(e.DivisionByZero):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select 1 / %s", (0,))
+ await cur.fetchall()
+
+
+async def test_warn_close(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+async def test_execute_reuse(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert await cur.fetchone() == (1,)
+
+ await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert await cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+async def test_execute_error(aconn, stmt):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ await cur.execute(stmt)
+ await cur.close()
+
+
+async def test_executemany(aconn):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ await cur.executemany("select %s", [(1,), (2,)])
+ await cur.close()
+
+
+async def test_fetchone(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchone() == (2,)
+ assert await cur.fetchone() is None
+
+
+async def test_fetchmany(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert await cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert await cur.fetchone() == (4,)
+ assert await cur.fetchmany(3) == [(5,)]
+ assert await cur.fetchmany(3) == []
+
+
+async def test_fetchall(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert await cur.fetchall() == []
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchall() == [(2,), (3,)]
+ assert await cur.fetchall() == []
+
+
+async def test_nextset(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+async def test_no_result(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert (await cur.fetchall()) == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_standard_row_factory(aconn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ async with aconn.cursor("foo", row_factory=row_factory) as cur:
+ await cur.execute("select generate_series(1, 5) as bar")
+ assert getter(await cur.fetchone()) == 1
+ assert list(map(getter, await cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, await cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_row_factory(aconn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ await cur.execute("select generate_series(1, 3) as x")
+ recs = await cur.fetchall()
+ await cur.scroll(0, "absolute")
+ while True:
+ rec = await cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ await cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"x": 1}
+ await cur.close()
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ await cur.fetchall()
+ assert cur.rownumber == 42
+ await cur.close()
+
+
+async def test_iter(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(1,), (2,), (3,)]
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(2,), (3,)]
+
+
+async def test_iter_rownumber(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ async for row in cur:
+ assert cur.rownumber == row[0]
+
+
+async def test_itersize(aconn, acommands):
+ async with aconn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ acommands.popall() # flush begin and other noise
+
+ async for rec in cur:
+ pass
+ cmds = acommands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+async def test_cant_scroll_by_default(aconn):
+ cur = aconn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ await cur.scroll(0)
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scroll(aconn):
+ cur = aconn.cursor("tmp", scrollable=True)
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(9, mode="wat")
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ for i in range(4, -1, -1):
+ await curs.scroll(-1)
+ assert i == (await curs.fetchone())[0]
+ await curs.scroll(-1)
+ await curs.close()
+
+
+async def test_non_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ await curs.scroll(-1)
+ await curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+async def test_no_hold(aconn, kwargs):
+ async with aconn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ await curs.execute("select generate_series(0, 2)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ await curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+async def test_hold(aconn):
+ async with aconn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ await curs.execute("select generate_series(0, 5)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ assert await curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+async def test_steal_cursor(aconn, row_factory):
+ cur1 = aconn.cursor()
+ await cur1.execute(
+ "declare test cursor without hold for select generate_series(1, 6) as s"
+ )
+
+ cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = await cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert await cur2.fetchall() == [(5,), (6,)]
+ await cur2.close()
+
+
+async def test_stolen_cursor_close(aconn):
+ cur1 = aconn.cursor()
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()
+
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()
diff --git a/tests/test_sql.py b/tests/test_sql.py
new file mode 100644
index 0000000..42b6c63
--- /dev/null
+++ b/tests/test_sql.py
@@ -0,0 +1,604 @@
+# test_sql.py - tests for the psycopg2.sql module
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import datetime as dt
+
+import pytest
+
+from psycopg import pq, sql, ProgrammingError
+from psycopg.adapt import PyFormat
+from psycopg._encodings import py2pgenc
+from psycopg.types import TypeInfo
+from psycopg.types.string import StrDumper
+
+from .utils import eur
+from .fix_crdb import crdb_encoding, crdb_scs_off
+
+
+@pytest.mark.parametrize(
+ "obj, quoted",
+ [
+ ("foo\\bar", " E'foo\\\\bar'"),
+ ("hello", "'hello'"),
+ (42, "42"),
+ (True, "true"),
+ (None, "NULL"),
+ ],
+)
+def test_quote(obj, quoted):
+ assert sql.quote(obj) == quoted
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_quote_roundtrip(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+
+ for i in range(1, 256):
+ want = chr(i)
+ quoted = sql.quote(want)
+ got = conn.execute(f"select {quoted}::text").fetchone()[0]
+ assert want == got
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages, f"error with {want!r}"
+
+
+@pytest.mark.parametrize("dummy", [crdb_scs_off("off")])
+def test_quote_stable_despite_deranged_libpq(conn, dummy):
+ # Verify the libpq behaviour of PQescapeString using the last setting seen.
+ # Check that we are not affected by it.
+ good_str = " E'\\\\'"
+ good_bytes = " E'\\\\000'::bytea"
+ conn.execute("set standard_conforming_strings to on")
+ assert pq.Escaping().escape_string(b"\\") == b"\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ conn.execute("set standard_conforming_strings to off")
+ assert pq.Escaping().escape_string(b"\\") == b"\\\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ # Verify that the good values are actually good
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute("set escape_string_warning to on")
+ for scs in ("on", "off"):
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ cur = conn.execute(f"select {good_str}, {good_bytes}::bytea")
+ assert cur.fetchone() == ("\\", b"\x00")
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+class TestSqlFormat:
+ def test_pos(self, conn):
+ s = sql.SQL("select {} from {}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_pos_spec(self, conn):
+ s = sql.SQL("select {0} from {1}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ s = sql.SQL("select {1} from {0}").format(
+ sql.Identifier("table"), sql.Identifier("field")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_dict(self, conn):
+ s = sql.SQL("select {f} from {t}").format(
+ f=sql.Identifier("field"), t=sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_compose_literal(self, conn):
+ s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
+ s1 = s.as_string(conn)
+ assert s1 == "select '2016-12-31'::date;"
+
+ def test_compose_empty(self, conn):
+ s = sql.SQL("select foo;").format()
+ s1 = s.as_string(conn)
+ assert s1 == "select foo;"
+
+ def test_percent_escape(self, conn):
+ s = sql.SQL("42 % {0}").format(sql.Literal(7))
+ s1 = s.as_string(conn)
+ assert s1 == "42 % 7"
+
+ def test_braces_escape(self, conn):
+ s = sql.SQL("{{{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{7}"
+ s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{1,7}"
+
+ def test_compose_badnargs(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format()
+
+ def test_compose_badnargs_auto(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {};").format()
+ with pytest.raises(ValueError):
+ sql.SQL("select {} {1};").format(10, 20)
+ with pytest.raises(ValueError):
+ sql.SQL("select {0} {};").format(10, 20)
+
+ def test_compose_bad_args_type(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format(a=10)
+ with pytest.raises(KeyError):
+ sql.SQL("select {x};").format(10)
+
+ def test_no_modifiers(self):
+ with pytest.raises(ValueError):
+ sql.SQL("select {a!r};").format(a=10)
+ with pytest.raises(ValueError):
+ sql.SQL("select {a:<};").format(a=10)
+
+ def test_must_be_adaptable(self, conn):
+ class Foo:
+ pass
+
+ s = sql.SQL("select {0};").format(sql.Literal(Foo()))
+ with pytest.raises(ProgrammingError):
+ s.as_string(conn)
+
+ def test_auto_literal(self, conn):
+ s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1))
+ assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date"
+
+ def test_execute(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.execute(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ (10, "a", "b", "c"),
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall() == [(10, "a", "b", "c")]
+
+ def test_executemany(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.executemany(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ [(10, "a", "b", "c"), (20, "d", "e", "f")],
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")]
+
+ @pytest.mark.crdb_skip("copy")
+ def test_copy(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+
+ with cur.copy(
+ sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ ),
+ ) as copy:
+ copy.write_row((10, "a", "b", "c"))
+ copy.write_row((20, "d", "e", "f"))
+
+ with cur.copy(
+ sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ )
+ ) as copy:
+ assert list(copy) == [b"c\n", b"f\n"]
+
+
+class TestIdentifier:
+ def test_class(self):
+ assert issubclass(sql.Identifier, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
+ with pytest.raises(TypeError):
+ sql.Identifier()
+ with pytest.raises(TypeError):
+ sql.Identifier(10) # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ sql.Identifier(dt.date(2016, 12, 31)) # type: ignore[arg-type]
+
+ def test_repr(self):
+ obj = sql.Identifier("fo'o")
+ assert repr(obj) == 'Identifier("fo\'o")'
+ assert repr(obj) == str(obj)
+
+ obj = sql.Identifier("fo'o", 'ba"r')
+ assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')"
+ assert repr(obj) == str(obj)
+
+ def test_eq(self):
+ assert sql.Identifier("foo") == sql.Identifier("foo")
+ assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar")
+ assert sql.Identifier("foo") != sql.Identifier("bar")
+ assert sql.Identifier("foo") != "foo"
+ assert sql.Identifier("foo") != sql.SQL("foo")
+
+ @pytest.mark.parametrize(
+ "args, want",
+ [
+ (("foo",), '"foo"'),
+ (("foo", "bar"), '"foo"."bar"'),
+ (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+ ],
+ )
+ def test_as_string(self, conn, args, want):
+ assert sql.Identifier(*args).as_string(conn) == want
+
+ @pytest.mark.parametrize(
+ "args, want, enc",
+ [
+ crdb_encoding(("foo",), '"foo"', "ascii"),
+ crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+ crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+ (("foo", eur), f'"foo"."{eur}"', "utf8"),
+ crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+ ],
+ )
+ def test_as_bytes(self, conn, args, want, enc):
+ want = want.encode(enc)
+ conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
+ assert sql.Identifier(*args).as_bytes(conn) == want
+
+ def test_join(self):
+ assert not hasattr(sql.Identifier("foo"), "join")
+
+
+class TestLiteral:
+ def test_class(self):
+ assert issubclass(sql.Literal, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Literal("foo"), sql.Literal)
+ assert isinstance(sql.Literal("foo"), sql.Literal)
+ assert isinstance(sql.Literal(b"foo"), sql.Literal)
+ assert isinstance(sql.Literal(42), sql.Literal)
+ assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal)
+
+ def test_repr(self):
+ assert repr(sql.Literal("foo")) == "Literal('foo')"
+ assert str(sql.Literal("foo")) == "Literal('foo')"
+
+ def test_as_string(self, conn):
+ assert sql.Literal(None).as_string(conn) == "NULL"
+ assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
+ assert sql.Literal(42).as_string(conn) == "42"
+ assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
+
+ def test_as_bytes(self, conn):
+ assert sql.Literal(None).as_bytes(conn) == b"NULL"
+ assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
+ assert sql.Literal(42).as_bytes(conn) == b"42"
+ assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes_encoding(self, conn, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode(encoding)
+
+ def test_eq(self):
+ assert sql.Literal("foo") == sql.Literal("foo")
+ assert sql.Literal("foo") != sql.Literal("bar")
+ assert sql.Literal("foo") != "foo"
+ assert sql.Literal("foo") != sql.SQL("foo")
+
+ def test_must_be_adaptable(self, conn):
+ class Foo:
+ pass
+
+ with pytest.raises(ProgrammingError):
+ sql.Literal(Foo()).as_string(conn)
+
+ def test_array(self, conn):
+ assert (
+ sql.Literal([dt.date(2000, 1, 1)]).as_string(conn)
+ == "'{2000-01-01}'::date[]"
+ )
+
+ def test_short_name_builtin(self, conn):
+ assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time"
+ assert (
+ sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn)
+ == "'2000-01-01 00:00:00'::timestamp"
+ )
+ assert (
+ sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn)
+ == "'{\"2000-01-01 00:00:00\"}'::timestamp[]"
+ )
+
+ def test_text_literal(self, conn):
+ conn.adapters.register_dumper(str, StrDumper)
+ assert sql.Literal("foo").as_string(conn) == "'foo'"
+
+ @pytest.mark.crdb_skip("composite") # create type, actually
+ @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"])
+ def test_invalid_name(self, conn, name):
+ conn.execute(
+ f"""
+ set client_encoding to utf8;
+ create type "{name}";
+ create function invin(cstring) returns "{name}"
+ language internal immutable strict as 'textin';
+ create function invout("{name}") returns cstring
+ language internal immutable strict as 'textout';
+ create type "{name}" (input=invin, output=invout, like=text);
+ """
+ )
+ info = TypeInfo.fetch(conn, f'"{name}"')
+
+ class InvDumper(StrDumper):
+ oid = info.oid
+
+ def dump(self, obj):
+ rv = super().dump(obj)
+ return b"%s-inv" % rv
+
+ info.register(conn)
+ conn.adapters.register_dumper(str, InvDumper)
+
+ assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format("hello"))
+ assert cur.fetchone()[0] == "hello-inv"
+
+ assert (
+ sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]"
+ )
+ cur = conn.execute(sql.SQL("select {}").format(["hello"]))
+ assert cur.fetchone()[0] == ["hello-inv"]
+
+
+class TestSQL:
+ def test_class(self):
+ assert issubclass(sql.SQL, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.SQL("foo"), sql.SQL)
+ assert isinstance(sql.SQL("foo"), sql.SQL)
+ with pytest.raises(TypeError):
+ sql.SQL(10) # type: ignore[arg-type]
+ with pytest.raises(TypeError):
+ sql.SQL(dt.date(2016, 12, 31)) # type: ignore[arg-type]
+
+ def test_repr(self, conn):
+ assert repr(sql.SQL("foo")) == "SQL('foo')"
+ assert str(sql.SQL("foo")) == "SQL('foo')"
+ assert sql.SQL("foo").as_string(conn) == "foo"
+
+ def test_eq(self):
+ assert sql.SQL("foo") == sql.SQL("foo")
+ assert sql.SQL("foo") != sql.SQL("bar")
+ assert sql.SQL("foo") != "foo"
+ assert sql.SQL("foo") != sql.Literal("foo")
+
+ def test_sum(self, conn):
+ obj = sql.SQL("foo") + sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.SQL("f") + sql.SQL("oo")
+ obj += sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_multiply(self, conn):
+ obj = sql.SQL("foo") * 3
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foofoofoo"
+
+ def test_join(self, conn):
+ obj = sql.SQL(", ").join(
+ [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == '"foo", bar, 42'
+
+ obj = sql.SQL(", ").join(
+ sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)])
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == '"foo", bar, 42'
+
+ obj = sql.SQL(", ").join([])
+ assert obj == sql.Composed([])
+
+ def test_as_string(self, conn):
+ assert sql.SQL("foo").as_string(conn) == "foo"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes(self, conn, encoding):
+ if encoding:
+ conn.execute(f"set client_encoding to {encoding}")
+
+ assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
+
+
+class TestComposed:
+ def test_class(self):
+ assert issubclass(sql.Composed, sql.Composable)
+
+ def test_repr(self):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
+ assert str(obj) == repr(obj)
+
+ def test_eq(self):
+ L = [sql.Literal("foo"), sql.Identifier("b'ar")]
+ l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
+ assert sql.Composed(L) == sql.Composed(list(L))
+ assert sql.Composed(L) != L
+ assert sql.Composed(L) != sql.Composed(l2)
+
+ def test_join(self, conn):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ obj = obj.join(", ")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "'foo', \"b'ar\""
+
+ def test_auto_literal(self, conn):
+ obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)])
+ obj = obj.join(", ")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date"
+
+ def test_sum(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj = obj + sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Composed([sql.Literal("bar")])
+ assert isinstance(obj, sql.Composed)
+ assert no_e(obj.as_string(conn)) == "foo 'bar'"
+
+ def test_iter(self):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ it = iter(obj)
+ i = next(it)
+ assert i == sql.SQL("foo")
+ i = next(it)
+ assert i == sql.SQL("bar")
+ with pytest.raises(StopIteration):
+ next(it)
+
+ def test_as_string(self, conn):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ assert obj.as_string(conn) == "foobar"
+
+ def test_as_bytes(self, conn):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ assert obj.as_bytes(conn) == b"foobar"
+
+ @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+ def test_as_bytes_encoding(self, conn, encoding):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)])
+ conn.execute(f"set client_encoding to {encoding}")
+ assert obj.as_bytes(conn) == ("foo" + eur).encode(encoding)
+
+
+class TestPlaceholder:
+ def test_class(self):
+ assert issubclass(sql.Placeholder, sql.Composable)
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_repr_format(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ add = f"format={format.name}" if format != PyFormat.AUTO else ""
+ assert str(ph) == repr(ph) == f"Placeholder({add})"
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_repr_name_format(self, conn, format):
+ ph = sql.Placeholder("foo", format=format)
+ add = f", format={format.name}" if format != PyFormat.AUTO else ""
+ assert str(ph) == repr(ph) == f"Placeholder('foo'{add})"
+
+ def test_bad_name(self):
+ with pytest.raises(ValueError):
+ sql.Placeholder(")")
+
+ def test_eq(self):
+ assert sql.Placeholder("foo") == sql.Placeholder("foo")
+ assert sql.Placeholder("foo") != sql.Placeholder("bar")
+ assert sql.Placeholder("foo") != "foo"
+ assert sql.Placeholder() == sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Literal("foo")
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_as_string(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ assert ph.as_string(conn) == f"%{format.value}"
+
+ ph = sql.Placeholder(name="foo", format=format)
+ assert ph.as_string(conn) == f"%(foo){format.value}"
+
+ @pytest.mark.parametrize("format", PyFormat)
+ def test_as_bytes(self, conn, format):
+ ph = sql.Placeholder(format=format)
+ assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+
+ ph = sql.Placeholder(name="foo", format=format)
+ assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+
+
+class TestValues:
+ def test_null(self, conn):
+ assert isinstance(sql.NULL, sql.SQL)
+ assert sql.NULL.as_string(conn) == "NULL"
+
+ def test_default(self, conn):
+ assert isinstance(sql.DEFAULT, sql.SQL)
+ assert sql.DEFAULT.as_string(conn) == "DEFAULT"
+
+
+def no_e(s):
+ """Drop an eventual E from E'' quotes"""
+ if isinstance(s, memoryview):
+ s = bytes(s)
+
+ if isinstance(s, str):
+ return re.sub(r"\bE'", "'", s)
+ elif isinstance(s, bytes):
+ return re.sub(rb"\bE'", b"'", s)
+ else:
+ raise TypeError(f"not dealing with {type(s).__name__}: {s}")
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
new file mode 100644
index 0000000..91a04e0
--- /dev/null
+++ b/tests/test_tpc.py
@@ -0,0 +1,325 @@
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+pytestmark = pytest.mark.crdb_skip("2-phase commit")
+
+
+def test_tpc_disabled(conn, pipeline):
+ val = int(conn.execute("show max_prepared_transactions").fetchone()[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ conn.rollback()
+ conn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.tpc_prepare()
+
+
+class TestTPC:
+ def test_tpc_commit(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_commit_recovered(self, conn_cls, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ with conn_cls.connect(dsn) as conn:
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_commit(xid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ def test_tpc_rollback(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_one_phase(self, conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_rollback()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_tpc_rollback_recovered(self, conn_cls, conn, dsn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ conn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ with conn_cls.connect(dsn) as conn:
+ xid = conn.xid(1, "gtrid", "bqual")
+ conn.tpc_rollback(xid)
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ def test_status_after_recover(self, conn, tpc):
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ conn.tpc_recover()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+ conn.tpc_recover()
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ def test_recovered_xids(self, conn, tpc):
+ # insert a few test xns
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("begin; prepare transaction '1-foo'")
+ cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (conn.info.dbname,),
+ )
+ okvals = cur.fetchall()
+ okvals.sort()
+
+ xids = conn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == conn.info.dbname]
+ xids.sort(key=lambda x: x.gtrid)
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ def test_xid_encoding(self, conn, tpc):
+ xid = conn.xid(42, "gtrid", "bqual")
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+
+ cur = conn.cursor()
+ cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (conn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ def test_xid_roundtrip(self, conn_cls, conn, dsn, tpc, fid, gtrid, bqual):
+ xid = conn.xid(fid, gtrid, bqual)
+ conn.tpc_begin(xid)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
+
+ assert len(xids) == 1
+ xid = xids[0]
+ conn.tpc_rollback(xid)
+
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ def test_unparsed_roundtrip(self, conn_cls, conn, dsn, tpc, tid):
+ conn.tpc_begin(tid)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
+
+ assert len(xids) == 1
+ xid = xids[0]
+ conn.tpc_rollback(xid)
+
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ def test_xid_unicode(self, conn_cls, conn, dsn, tpc):
+ x1 = conn.xid(10, "uni", "code")
+ conn.tpc_begin(x1)
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ def test_xid_unicode_unparsed(self, conn_cls, conn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ conn.execute("set client_encoding to utf8")
+ conn.commit()
+
+ conn.tpc_begin("transaction-id")
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ def test_cancel_fails_prepared(self, conn, tpc):
+ conn.tpc_begin("cancel")
+ conn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.cancel()
+
+ def test_tpc_recover_non_dbapi_connection(self, conn_cls, conn, dsn, tpc):
+ conn.row_factory = psycopg.rows.dict_row
+ conn.tpc_begin("dict-connection")
+ conn.tpc_prepare()
+ conn.close()
+
+ with conn_cls.connect(dsn) as conn:
+ xids = conn.tpc_recover()
+ xid = [x for x in xids if x.database == conn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None
+
+
+class TestXidObject:
+ def test_xid_construction(self):
+ x1 = psycopg.Xid(74, "foo", "bar")
+ 74 == x1.format_id
+ "foo" == x1.gtrid
+ "bar" == x1.bqual
+
+ def test_xid_from_string(self):
+ x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ 42 == x2.format_id
+ "gtrid" == x2.gtrid
+ "bqual" == x2.bqual
+
+ x3 = psycopg.Xid.from_string("99_xxx_yyy")
+ None is x3.format_id
+ "99_xxx_yyy" == x3.gtrid
+ None is x3.bqual
+
+ def test_xid_to_string(self):
+ x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=")
+ str(x1) == "42_Z3RyaWQ=_YnF1YWw="
+
+ x2 = psycopg.Xid.from_string("99_xxx_yyy")
+ str(x2) == "99_xxx_yyy"
diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py
new file mode 100644
index 0000000..a409a2e
--- /dev/null
+++ b/tests/test_tpc_async.py
@@ -0,0 +1,310 @@
+import pytest
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("2-phase commit"),
+]
+
+
+async def test_tpc_disabled(aconn, apipeline):
+ cur = await aconn.execute("show max_prepared_transactions")
+ val = int((await cur.fetchone())[0])
+ if val:
+ pytest.skip("prepared transactions enabled")
+
+ await aconn.rollback()
+ await aconn.tpc_begin("x")
+ with pytest.raises(psycopg.NotSupportedError):
+ await aconn.tpc_prepare()
+
+
+class TestTPC:
+ async def test_tpc_commit(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_commit()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_commit(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
+
+ async def test_tpc_rollback(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_one_phase(self, aconn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_rollback()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc):
+ xid = aconn.xid(1, "gtrid", "bqual")
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ await aconn.tpc_begin(xid)
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = aconn.cursor()
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ await aconn.tpc_prepare()
+ await aconn.close()
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = aconn.xid(1, "gtrid", "bqual")
+ await aconn.tpc_rollback(xid)
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ async def test_status_after_recover(self, aconn, tpc):
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.IDLE
+
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+ await aconn.tpc_recover()
+ assert aconn.info.transaction_status == TransactionStatus.INTRANS
+
+ async def test_recovered_xids(self, aconn, tpc):
+ # insert a few test xns
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("begin; prepare transaction '1-foo'")
+ await cur.execute("begin; prepare transaction '2-bar'")
+
+ # read the values to return
+ await cur.execute(
+ """
+ select gid, prepared, owner, database from pg_prepared_xacts
+ where database = %s
+ """,
+ (aconn.info.dbname,),
+ )
+ okvals = await cur.fetchall()
+ okvals.sort()
+
+ xids = await aconn.tpc_recover()
+ xids = [xid for xid in xids if xid.database == aconn.info.dbname]
+ xids.sort(key=lambda x: x.gtrid)
+
+ # check the values returned
+ assert len(okvals) == len(xids)
+ for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
+ assert xid.gtrid == gid
+ assert xid.prepared == prepared
+ assert xid.owner == owner
+ assert xid.database == database
+
+ async def test_xid_encoding(self, aconn, tpc):
+ xid = aconn.xid(42, "gtrid", "bqual")
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+
+ cur = aconn.cursor()
+ await cur.execute(
+ "select gid from pg_prepared_xacts where database = %s",
+ (aconn.info.dbname,),
+ )
+ assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0]
+
+ @pytest.mark.parametrize(
+ "fid, gtrid, bqual",
+ [
+ (0, "", ""),
+ (42, "gtrid", "bqual"),
+ (0x7FFFFFFF, "x" * 64, "y" * 64),
+ ],
+ )
+ async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual):
+ xid = aconn.xid(fid, gtrid, bqual)
+ await aconn.tpc_begin(xid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id == fid
+ assert xid.gtrid == gtrid
+ assert xid.bqual == bqual
+
+ @pytest.mark.parametrize(
+ "tid",
+ [
+ "",
+ "hello, world!",
+ "x" * 199, # PostgreSQL's limit in transaction id length
+ ],
+ )
+ async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid):
+ await aconn.tpc_begin(tid)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ]
+ assert len(xids) == 1
+ xid = xids[0]
+ await aconn.tpc_rollback(xid)
+
+ assert xid.format_id is None
+ assert xid.gtrid == tid
+ assert xid.bqual is None
+
+ async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc):
+ x1 = aconn.xid(10, "uni", "code")
+ await aconn.tpc_begin(x1)
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert 10 == xid.format_id
+ assert "uni" == xid.gtrid
+ assert "code" == xid.bqual
+
+ async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc):
+ # We don't expect people shooting snowmen as transaction ids,
+ # so if something explodes in an encode error I don't mind.
+ # Let's just check unicode is accepted as type.
+ await aconn.execute("set client_encoding to utf8")
+ await aconn.commit()
+
+ await aconn.tpc_begin("transaction-id")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xid = [
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
+ ][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "transaction-id"
+ assert xid.bqual is None
+
+ async def test_cancel_fails_prepared(self, aconn, tpc):
+ await aconn.tpc_begin("cancel")
+ await aconn.tpc_prepare()
+ with pytest.raises(psycopg.ProgrammingError):
+ aconn.cancel()
+
+ async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc):
+ aconn.row_factory = psycopg.rows.dict_row
+ await aconn.tpc_begin("dict-connection")
+ await aconn.tpc_prepare()
+ await aconn.close()
+
+ async with await aconn_cls.connect(dsn) as aconn:
+ xids = await aconn.tpc_recover()
+ xid = [x for x in xids if x.database == aconn.info.dbname][0]
+
+ assert xid.format_id is None
+ assert xid.gtrid == "dict-connection"
+ assert xid.bqual is None
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644
index 0000000..9391e00
--- /dev/null
+++ b/tests/test_transaction.py
@@ -0,0 +1,796 @@
+import sys
+import logging
+from threading import Thread, Event
+
+import pytest
+
+import psycopg
+from psycopg import Rollback
+from psycopg import errors as e
+
+# TODOCRDB: is this the expected behaviour?
+crdb_skip_external_observer = pytest.mark.crdb(
+ "skip", reason="deadlock on observer connection"
+)
+
+
+@pytest.fixture
+def conn(conn, pipeline):
+ return conn
+
+
+@pytest.fixture(autouse=True)
+def create_test_table(svcconn):
+ """Creates a table called 'test_table' for use in tests."""
+ cur = svcconn.cursor()
+ cur.execute("drop table if exists test_table")
+ cur.execute("create table test_table (id text primary key)")
+ yield
+ cur.execute("drop table test_table")
+
+
+def insert_row(conn, value):
+ sql = "INSERT INTO test_table VALUES (%s)"
+ if isinstance(conn, psycopg.Connection):
+ conn.cursor().execute(sql, (value,))
+ else:
+
+ async def f():
+ cur = conn.cursor()
+ await cur.execute(sql, (value,))
+
+ return f()
+
+
+def inserted(conn):
+ """Return the values inserted in the test table."""
+ sql = "SELECT * FROM test_table"
+ if isinstance(conn, psycopg.Connection):
+ rows = conn.cursor().execute(sql).fetchall()
+ return set(v for (v,) in rows)
+ else:
+
+ async def f():
+ cur = conn.cursor()
+ await cur.execute(sql)
+ rows = await cur.fetchall()
+ return set(v for (v,) in rows)
+
+ return f()
+
+
+def in_transaction(conn):
+ if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
+ return False
+ elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS:
+ return True
+ else:
+ assert False, conn.pgconn.transaction_status
+
+
+def get_exc_info(exc):
+ """Return the exc info for an exception or a success if exc is None"""
+ if not exc:
+ return (None,) * 3
+ try:
+ raise exc
+ except exc:
+ return sys.exc_info()
+
+
+class ExpectedException(Exception):
+ pass
+
+
+def test_basic(conn, pipeline):
+ """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+ assert not in_transaction(conn)
+ with conn.transaction():
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
+
+
+def test_exposes_associated_connection(conn):
+ """Transaction exposes its connection as a read-only property."""
+ with conn.transaction() as tx:
+ assert tx.connection is conn
+ with pytest.raises(AttributeError):
+ tx.connection = conn
+
+
+def test_exposes_savepoint_name(conn):
+ """Transaction exposes its savepoint name as a read-only property."""
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert tx.savepoint_name == "foo"
+ with pytest.raises(AttributeError):
+ tx.savepoint_name = "bar"
+
+
+def test_cant_reenter(conn):
+ with conn.transaction() as tx:
+ pass
+
+ with pytest.raises(TypeError):
+ with tx:
+ pass
+
+
+def test_begins_on_enter(conn, pipeline):
+ """Transaction does not begin until __enter__() is called."""
+ tx = conn.transaction()
+ assert not in_transaction(conn)
+ with tx:
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
+
+
+def test_commit_on_successful_exit(conn):
+ """Changes are committed on successful exit from the `with` block."""
+ with conn.transaction():
+ insert_row(conn, "foo")
+
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"foo"}
+
+
+def test_rollback_on_exception_exit(conn):
+ """Changes are rolled back if an exception escapes the `with` block."""
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise ExpectedException("This discards the insert")
+
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+def test_context_inerror_rollback_no_clobber(conn_cls, conn, pipeline, dsn, caplog):
+ if pipeline:
+ # Only 'conn' is possibly in pipeline mode, but the transaction and
+ # checks are on 'conn2'.
+ pytest.skip("not applicable")
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ with conn_cls.connect(dsn) as conn2:
+ with conn2.transaction():
+ conn2.execute("select 1")
+ conn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ conn = conn_cls.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ conn.close()
+
+
+def test_interaction_dbapi_transaction(conn):
+ insert_row(conn, "foo")
+
+ with conn.transaction():
+ insert_row(conn, "bar")
+ raise Rollback
+
+ with conn.transaction():
+ insert_row(conn, "baz")
+
+ assert in_transaction(conn)
+ conn.commit()
+ assert inserted(conn) == {"foo", "baz"}
+
+
+def test_prohibits_use_of_commit_rollback_autocommit(conn):
+ """
+ Within a Transaction block, it is forbidden to touch commit, rollback,
+ or the autocommit setting on the connection, as this would interfere
+ with the transaction scope being managed by the Transaction block.
+ """
+ conn.autocommit = False
+ conn.commit()
+ conn.rollback()
+
+ with conn.transaction():
+ with pytest.raises(e.ProgrammingError):
+ conn.autocommit = False
+ with pytest.raises(e.ProgrammingError):
+ conn.commit()
+ with pytest.raises(e.ProgrammingError):
+ conn.rollback()
+
+ conn.autocommit = False
+ conn.commit()
+ conn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+def test_preserves_autocommit(conn, autocommit):
+ """
+ Connection.autocommit is unchanged both during and after Transaction block.
+ """
+ conn.autocommit = autocommit
+ with conn.transaction():
+ assert conn.autocommit is autocommit
+ assert conn.autocommit is autocommit
+
+
+def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are committed
+ """
+ conn.autocommit = False
+ assert not in_transaction(conn)
+ with conn.transaction():
+ insert_row(conn, "new")
+ assert not in_transaction(conn)
+
+ # Changes committed
+ assert inserted(conn) == {"new"}
+ assert inserted(svcconn) == {"new"}
+
+
+def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made within Transaction context are discarded
+ """
+ conn.autocommit = False
+ assert not in_transaction(conn)
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "new")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+
+ # Changes discarded
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_autocommit_off_and_tx_in_progress_successful_exit(conn, pipeline, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are left intact
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ conn.autocommit = False
+ insert_row(conn, "prior")
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ with conn.transaction():
+ insert_row(conn, "new")
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior", "new"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_autocommit_off_and_tx_in_progress_exception_exit(conn, pipeline, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made before the Transaction context are left intact
+ * Changes made within Transaction context are discarded
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ conn.autocommit = False
+ insert_row(conn, "prior")
+ if pipeline:
+ pipeline.sync()
+ assert in_transaction(conn)
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "new")
+ raise ExpectedException()
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
+ """Changes from nested transaction contexts are all persisted on exit."""
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ insert_row(conn, "outer-after")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in outer context escapes.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in inner context escapes the outer context.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
+ """
+ An exception escaping the inner transaction context causes changes made
+ within that inner context to be discarded, but the error can then be
+ handled in the outer context, allowing changes made in the outer context
+ (both before, and after, the inner context) to be successfully committed.
+ """
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with pytest.raises(ExpectedException):
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise ExpectedException()
+ insert_row(conn, "outer-after")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+def test_nested_three_levels_successful_exit(conn, svcconn):
+ """Exercise management of more than one savepoint."""
+ with conn.transaction(): # BEGIN
+ insert_row(conn, "one")
+ with conn.transaction(): # SAVEPOINT s1
+ insert_row(conn, "two")
+ with conn.transaction(): # SAVEPOINT s2
+ insert_row(conn, "three")
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
+
+
+def test_named_savepoint_escapes_savepoint_name(conn):
+ with conn.transaction("s-1"):
+ pass
+ with conn.transaction("s1; drop table students"):
+ pass
+
+
+def test_named_savepoints_successful_exit(conn, commands):
+ """
+ Entering a transaction context will do one of these these things:
+ 1. Begin an outer transaction (if one isn't already in progress)
+ 2. Begin an outer transaction and create a savepoint (if one is named)
+ 3. Create a savepoint (if a transaction is already in progress)
+ either using the name provided, or auto-generating a savepoint name.
+
+ ...and exiting the context successfully will "commit" the same.
+ """
+ # Case 1
+ # Using Transaction explicitly because conn.transaction() enters the contetx
+ assert not commands
+ with conn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 1 (with a transaction already started)
+ conn.cursor().execute("select 1")
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_1"']
+ assert tx.savepoint_name == "_pg3_1"
+ assert commands.popall() == ['RELEASE "_pg3_1"']
+ conn.rollback()
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name provided)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ assert commands.popall() == ['RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ assert commands.popall() == ['RELEASE "_pg3_2"']
+ assert commands.popall() == ["COMMIT"]
+
+
+def test_named_savepoints_exception_exit(conn, commands):
+ """
+ Same as the previous test but checks that when exiting the context with an
+ exception, whatever transaction and/or savepoint was started on enter will
+ be rolled-back as appropriate.
+ """
+ # Case 1
+ with pytest.raises(ExpectedException):
+ with conn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with pytest.raises(ExpectedException):
+ with conn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 3 (with savepoint name provided)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ with conn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ raise ExpectedException
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ with conn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ with conn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ raise ExpectedException
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
+ assert commands.popall() == ["COMMIT"]
+
+
+def test_named_savepoints_with_repeated_names_works(conn):
+ """
+ Using the same savepoint name repeatedly works correctly, but bypasses
+ some sanity checks.
+ """
+ # Works correctly if no inner transactions are rolled back
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("sp"):
+ insert_row(conn, "tx1")
+ with conn.transaction("sp"):
+ insert_row(conn, "tx2")
+ with conn.transaction("sp"):
+ insert_row(conn, "tx3")
+ assert inserted(conn) == {"tx1", "tx2", "tx3"}
+
+ # Works correctly if one level of inner transaction is rolled back
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("s1"):
+ insert_row(conn, "tx1")
+ with conn.transaction("s1", force_rollback=True):
+ insert_row(conn, "tx2")
+ with conn.transaction("s1"):
+ insert_row(conn, "tx3")
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
+
+ # Works correctly if multiple inner transactions are rolled back
+ # (This scenario mandates releasing savepoints after rolling back to them.)
+ with conn.transaction(force_rollback=True):
+ with conn.transaction("s1"):
+ insert_row(conn, "tx1")
+ with conn.transaction("s1") as tx2:
+ insert_row(conn, "tx2")
+ with conn.transaction("s1"):
+ insert_row(conn, "tx3")
+ raise Rollback(tx2)
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
+
+
+def test_force_rollback_successful_exit(conn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with conn.transaction(force_rollback=True):
+ insert_row(conn, "foo")
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+def test_force_rollback_exception_exit(conn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with pytest.raises(ExpectedException):
+ with conn.transaction(force_rollback=True):
+ insert_row(conn, "foo")
+ raise ExpectedException()
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_discards_changes(conn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block exits the block and
+ discards all changes made within that block.
+
+ You can raise any of the following:
+ - Rollback (type)
+ - Rollback() (instance)
+ - Rollback(tx) (instance initialised with reference to the transaction)
+ All of these are equivalent.
+ """
+
+ def assert_no_rows():
+ assert not inserted(conn)
+ assert not inserted(svcconn)
+
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise Rollback
+ assert_no_rows()
+
+ with conn.transaction():
+ insert_row(conn, "foo")
+ raise Rollback()
+ assert_no_rows()
+
+ with conn.transaction() as tx:
+ insert_row(conn, "foo")
+ raise Rollback(tx)
+ assert_no_rows()
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_outer_tx_unaffected(conn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block does not impact an
+ enclosing transaction block.
+ """
+ with conn.transaction():
+ insert_row(conn, "before")
+ with conn.transaction():
+ insert_row(conn, "during")
+ raise Rollback
+ assert in_transaction(conn)
+ assert not inserted(svcconn)
+ insert_row(conn, "after")
+ assert inserted(conn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
+
+
+def test_explicit_rollback_of_outer_transaction(conn):
+ """
+ Raising a Rollback exception that references an outer transaction will
+ discard all changes from both inner and outer transaction blocks.
+ """
+ with conn.transaction() as outer_tx:
+ insert_row(conn, "outer")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise Rollback(outer_tx)
+ assert False, "This line of code should be unreachable."
+ assert not inserted(conn)
+
+
+@crdb_skip_external_observer
+def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
+ """
+ Rolling-back an enclosing transaction does not impact an outer transaction.
+ """
+ with conn.transaction():
+ insert_row(conn, "outer-before")
+ with conn.transaction() as tx_enclosing:
+ insert_row(conn, "enclosing")
+ with conn.transaction():
+ insert_row(conn, "inner")
+ raise Rollback(tx_enclosing)
+ insert_row(conn, "outer-after")
+
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+def test_str(conn, pipeline):
+ with conn.transaction() as tx:
+ if pipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx)
+ else:
+ assert "[INTRANS]" in str(tx)
+ assert "(active)" in str(tx)
+ assert "'" not in str(tx)
+ with conn.transaction("wat") as tx2:
+ if pipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx2)
+ else:
+ assert "[INTRANS]" in str(tx2)
+ assert "'wat'" in str(tx2)
+
+ if pipeline:
+ assert "[IDLE, pipeline=ON]" in str(tx)
+ else:
+ assert "[IDLE]" in str(tx)
+ assert "(terminated)" in str(tx)
+
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction() as tx:
+ 1 / 0
+
+ assert "(terminated)" in str(tx)
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_exit(conn, exit_error):
+ conn.autocommit = True
+
+ t1 = conn.transaction()
+ t1.__enter__()
+
+ t2 = conn.transaction()
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_implicit_begin(conn, exit_error):
+ conn.execute("select 1")
+
+ t1 = conn.transaction()
+ t1.__enter__()
+
+ t2 = conn.transaction()
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_exit_same_name(conn, exit_error):
+ conn.autocommit = True
+
+ t1 = conn.transaction("save")
+ t1.__enter__()
+ t2 = conn.transaction("save")
+ t2.__enter__()
+
+ with pytest.raises(e.ProgrammingError):
+ t1.__exit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ t2.__exit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+def test_concurrency(conn, what):
+ conn.autocommit = True
+
+ evs = [Event() for i in range(3)]
+
+ def worker(unlock, wait_on):
+ with pytest.raises(e.ProgrammingError) as ex:
+ with conn.transaction():
+ unlock.set()
+ wait_on.wait()
+ conn.execute("select 1")
+
+ if what == "error":
+ 1 / 0
+ elif what == "rollback":
+ raise Rollback()
+ else:
+ assert what == "commit"
+
+ if what == "error":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, ZeroDivisionError)
+ elif what == "rollback":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, Rollback)
+ else:
+ assert "transaction commit" in str(ex.value)
+
+ # Start a first transaction in a thread
+ t1 = Thread(target=worker, kwargs={"unlock": evs[0], "wait_on": evs[1]})
+ t1.start()
+ evs[0].wait()
+
+ # Start a nested transaction in a thread
+ t2 = Thread(target=worker, kwargs={"unlock": evs[1], "wait_on": evs[2]})
+ t2.start()
+
+ # Terminate the first transaction before the second does
+ t1.join()
+ evs[2].set()
+ t2.join()
diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py
new file mode 100644
index 0000000..55e1c9c
--- /dev/null
+++ b/tests/test_transaction_async.py
@@ -0,0 +1,743 @@
+import asyncio
+import logging
+
+import pytest
+
+from psycopg import Rollback
+from psycopg import errors as e
+from psycopg._compat import create_task
+
+from .test_transaction import in_transaction, insert_row, inserted, get_exc_info
+from .test_transaction import ExpectedException, crdb_skip_external_observer
+from .test_transaction import create_test_table # noqa # autouse fixture
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def aconn(aconn, apipeline):
+ return aconn
+
+
+async def test_basic(aconn, apipeline):
+ """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_exposes_associated_connection(aconn):
+ """Transaction exposes its connection as a read-only property."""
+ async with aconn.transaction() as tx:
+ assert tx.connection is aconn
+ with pytest.raises(AttributeError):
+ tx.connection = aconn
+
+
+async def test_exposes_savepoint_name(aconn):
+ """Transaction exposes its savepoint name as a read-only property."""
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert tx.savepoint_name == "foo"
+ with pytest.raises(AttributeError):
+ tx.savepoint_name = "bar"
+
+
+async def test_cant_reenter(aconn):
+ async with aconn.transaction() as tx:
+ pass
+
+ with pytest.raises(TypeError):
+ async with tx:
+ pass
+
+
+async def test_begins_on_enter(aconn, apipeline):
+ """Transaction does not begin until __enter__() is called."""
+ tx = aconn.transaction()
+ assert not in_transaction(aconn)
+ async with tx:
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_commit_on_successful_exit(aconn):
+ """Changes are committed on successful exit from the `with` block."""
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"foo"}
+
+
+async def test_rollback_on_exception_exit(aconn):
+ """Changes are rolled back if an exception escapes the `with` block."""
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise ExpectedException("This discards the insert")
+
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+
+
+@pytest.mark.crdb_skip("pg_terminate_backend")
+async def test_context_inerror_rollback_no_clobber(
+ aconn_cls, aconn, apipeline, dsn, caplog
+):
+ if apipeline:
+ # Only 'aconn' is possibly in pipeline mode, but the transaction and
+ # checks are on 'conn2'.
+ pytest.skip("not applicable")
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ with pytest.raises(ZeroDivisionError):
+ async with await aconn_cls.connect(dsn) as conn2:
+ async with conn2.transaction():
+ await conn2.execute("select 1")
+ await aconn.execute(
+ "select pg_terminate_backend(%s::int)",
+ [conn2.pgconn.backend_pid],
+ )
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+
+
+@pytest.mark.crdb_skip("copy")
+async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+
+ conn = await aconn_cls.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ async with conn.transaction():
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
+
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ await conn.close()
+
+
+async def test_interaction_dbapi_transaction(aconn):
+ await insert_row(aconn, "foo")
+
+ async with aconn.transaction():
+ await insert_row(aconn, "bar")
+ raise Rollback
+
+ async with aconn.transaction():
+ await insert_row(aconn, "baz")
+
+ assert in_transaction(aconn)
+ await aconn.commit()
+ assert await inserted(aconn) == {"foo", "baz"}
+
+
+async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
+ """
+ Within a Transaction block, it is forbidden to touch commit, rollback,
+ or the autocommit setting on the connection, as this would interfere
+ with the transaction scope being managed by the Transaction block.
+ """
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+ async with aconn.transaction():
+ with pytest.raises(e.ProgrammingError):
+ await aconn.set_autocommit(False)
+ with pytest.raises(e.ProgrammingError):
+ await aconn.commit()
+ with pytest.raises(e.ProgrammingError):
+ await aconn.rollback()
+
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+async def test_preserves_autocommit(aconn, autocommit):
+ """
+ Connection.autocommit is unchanged both during and after Transaction block.
+ """
+ await aconn.set_autocommit(autocommit)
+ async with aconn.transaction():
+ assert aconn.autocommit is autocommit
+ assert aconn.autocommit is autocommit
+
+
+async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are committed
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert not in_transaction(aconn)
+
+ # Changes committed
+ assert await inserted(aconn) == {"new"}
+ assert inserted(svcconn) == {"new"}
+
+
+async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made within Transaction context are discarded
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+
+ # Changes discarded
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_autocommit_off_and_tx_in_progress_successful_exit(
+ aconn, apipeline, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are left intact
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior", "new"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_autocommit_off_and_tx_in_progress_exception_exit(
+ aconn, apipeline, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made before the Transaction context are left intact
+ * Changes made within Transaction context are discarded
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ if apipeline:
+ await apipeline.sync()
+ assert in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn):
+ """Changes from nested transaction contexts are all persisted on exit."""
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in outer context escapes.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in inner context escapes the outer context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn):
+ """
+ An exception escaping the inner transaction context causes changes made
+ within that inner context to be discarded, but the error can then be
+ handled in the outer context, allowing changes made in the outer context
+ (both before, and after, the inner context) to be successfully committed.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_nested_three_levels_successful_exit(aconn, svcconn):
+ """Exercise management of more than one savepoint."""
+ async with aconn.transaction(): # BEGIN
+ await insert_row(aconn, "one")
+ async with aconn.transaction(): # SAVEPOINT s1
+ await insert_row(aconn, "two")
+ async with aconn.transaction(): # SAVEPOINT s2
+ await insert_row(aconn, "three")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
+
+
+async def test_named_savepoint_escapes_savepoint_name(aconn):
+ async with aconn.transaction("s-1"):
+ pass
+ async with aconn.transaction("s1; drop table students"):
+ pass
+
+
+async def test_named_savepoints_successful_exit(aconn, acommands):
+ """
+ Entering a transaction context will do one of these these things:
+ 1. Begin an outer transaction (if one isn't already in progress)
+ 2. Begin an outer transaction and create a savepoint (if one is named)
+ 3. Create a savepoint (if a transaction is already in progress)
+ either using the name provided, or auto-generating a savepoint name.
+
+ ...and exiting the context successfully will "commit" the same.
+ """
+ commands = acommands
+
+ # Case 1
+ # Using Transaction explicitly because conn.transaction() enters the contetx
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 1 (with a transaction already started)
+ await aconn.cursor().execute("select 1")
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_1"']
+ assert tx.savepoint_name == "_pg3_1"
+
+ assert commands.popall() == ['RELEASE "_pg3_1"']
+ await aconn.rollback()
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ assert commands.popall() == ['RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ assert commands.popall() == ['RELEASE "_pg3_2"']
+ assert commands.popall() == ["COMMIT"]
+
+
+async def test_named_savepoints_exception_exit(aconn, acommands):
+ """
+ Same as the previous test but checks that when exiting the context with an
+ exception, whatever transaction and/or savepoint was started on enter will
+ be rolled-back as appropriate.
+ """
+ commands = acommands
+
+ # Case 1
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ["BEGIN"]
+ assert not tx.savepoint_name
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 2
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
+ assert tx.savepoint_name == "foo"
+ raise ExpectedException
+ assert commands.popall() == ["ROLLBACK"]
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(savepoint_name="bar") as tx:
+ assert commands.popall() == ['SAVEPOINT "bar"']
+ assert tx.savepoint_name == "bar"
+ raise ExpectedException
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
+ assert commands.popall() == ["COMMIT"]
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.popall() == ["BEGIN"]
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction() as tx:
+ assert commands.popall() == ['SAVEPOINT "_pg3_2"']
+ assert tx.savepoint_name == "_pg3_2"
+ raise ExpectedException
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
+ assert commands.popall() == ["COMMIT"]
+
+
+async def test_named_savepoints_with_repeated_names_works(aconn):
+ """
+ Using the same savepoint name repeatedly works correctly, but bypasses
+ some sanity checks.
+ """
+ # Works correctly if no inner transactions are rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1", "tx2", "tx3"}
+
+ # Works correctly if one level of inner transaction is rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1", force_rollback=True):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+ # Works correctly if multiple inner transactions are rolled back
+ # (This scenario mandates releasing savepoints after rolling back to them.)
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1") as tx2:
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ raise Rollback(tx2)
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+
+async def test_force_rollback_successful_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_force_rollback_exception_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ raise ExpectedException()
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_discards_changes(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block exits the block and
+ discards all changes made within that block.
+
+ You can raise any of the following:
+ - Rollback (type)
+ - Rollback() (instance)
+ - Rollback(tx) (instance initialised with reference to the transaction)
+ All of these are equivalent.
+ """
+
+ async def assert_no_rows():
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback
+ await assert_no_rows()
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback()
+ await assert_no_rows()
+
+ async with aconn.transaction() as tx:
+ await insert_row(aconn, "foo")
+ raise Rollback(tx)
+ await assert_no_rows()
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block does not impact an
+ enclosing transaction block.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "before")
+ async with aconn.transaction():
+ await insert_row(aconn, "during")
+ raise Rollback
+ assert in_transaction(aconn)
+ assert not inserted(svcconn)
+ await insert_row(aconn, "after")
+ assert await inserted(aconn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
+
+
+async def test_explicit_rollback_of_outer_transaction(aconn):
+ """
+ Raising a Rollback exception that references an outer transaction will
+ discard all changes from both inner and outer transaction blocks.
+ """
+ async with aconn.transaction() as outer_tx:
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(outer_tx)
+ assert False, "This line of code should be unreachable."
+ assert not await inserted(aconn)
+
+
+@crdb_skip_external_observer
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn):
+ """
+ Rolling-back an enclosing transaction does not impact an outer transaction.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction() as tx_enclosing:
+ await insert_row(aconn, "enclosing")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(tx_enclosing)
+ await insert_row(aconn, "outer-after")
+
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_str(aconn, apipeline):
+ async with aconn.transaction() as tx:
+ if apipeline:
+ assert "[INTRANS]" not in str(tx)
+ await apipeline.sync()
+ assert "[INTRANS, pipeline=ON]" in str(tx)
+ else:
+ assert "[INTRANS]" in str(tx)
+ assert "(active)" in str(tx)
+ assert "'" not in str(tx)
+ async with aconn.transaction("wat") as tx2:
+ if apipeline:
+ assert "[INTRANS, pipeline=ON]" in str(tx2)
+ else:
+ assert "[INTRANS]" in str(tx2)
+ assert "'wat'" in str(tx2)
+
+ if apipeline:
+ assert "[IDLE, pipeline=ON]" in str(tx)
+ else:
+ assert "[IDLE]" in str(tx)
+ assert "(terminated)" in str(tx)
+
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.transaction() as tx:
+ 1 / 0
+
+ assert "(terminated)" in str(tx)
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_exit(aconn, exit_error):
+ await aconn.set_autocommit(True)
+
+ t1 = aconn.transaction()
+ await t1.__aenter__()
+
+ t2 = aconn.transaction()
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_implicit_begin(aconn, exit_error):
+ await aconn.execute("select 1")
+
+ t1 = aconn.transaction()
+ await t1.__aenter__()
+
+ t2 = aconn.transaction()
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_exit_same_name(aconn, exit_error):
+ await aconn.set_autocommit(True)
+
+ t1 = aconn.transaction("save")
+ await t1.__aenter__()
+ t2 = aconn.transaction("save")
+ await t2.__aenter__()
+
+ with pytest.raises(e.ProgrammingError):
+ await t1.__aexit__(*get_exc_info(exit_error))
+
+ with pytest.raises(e.ProgrammingError):
+ await t2.__aexit__(*get_exc_info(exit_error))
+
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+async def test_concurrency(aconn, what):
+ await aconn.set_autocommit(True)
+
+ evs = [asyncio.Event() for i in range(3)]
+
+ async def worker(unlock, wait_on):
+ with pytest.raises(e.ProgrammingError) as ex:
+ async with aconn.transaction():
+ unlock.set()
+ await wait_on.wait()
+ await aconn.execute("select 1")
+
+ if what == "error":
+ 1 / 0
+ elif what == "rollback":
+ raise Rollback()
+ else:
+ assert what == "commit"
+
+ if what == "error":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, ZeroDivisionError)
+ elif what == "rollback":
+ assert "transaction rollback" in str(ex.value)
+ assert isinstance(ex.value.__context__, Rollback)
+ else:
+ assert "transaction commit" in str(ex.value)
+
+ # Start a first transaction in a task
+ t1 = create_task(worker(unlock=evs[0], wait_on=evs[1]))
+ await evs[0].wait()
+
+ # Start a nested transaction in a task
+ t2 = create_task(worker(unlock=evs[1], wait_on=evs[2]))
+
+ # Terminate the first transaction before the second does
+ await asyncio.gather(t1)
+ evs[2].set()
+ await asyncio.gather(t2)
diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py
new file mode 100644
index 0000000..d0e57e6
--- /dev/null
+++ b/tests/test_typeinfo.py
@@ -0,0 +1,145 @@
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.pq import TransactionStatus
+from psycopg.types import TypeInfo
+
+
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch(conn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ conn.execute("select 1")
+
+ assert conn.info.transaction_status == status
+ info = TypeInfo.fetch(conn, name)
+ assert conn.info.transaction_status == status
+
+ assert info.name == "text"
+ # TODO: add the schema?
+ # assert info.schema == "pg_catalog"
+
+ assert info.oid == psycopg.adapters.types["text"].oid
+ assert info.array_oid == psycopg.adapters.types["text"].array_oid
+ assert info.regtype == "text"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_async(aconn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ await aconn.execute("select 1")
+
+ assert aconn.info.transaction_status == status
+ info = await TypeInfo.fetch(aconn, name)
+ assert aconn.info.transaction_status == status
+
+ assert info.name == "text"
+ # assert info.schema == "pg_catalog"
+ assert info.oid == psycopg.adapters.types["text"].oid
+ assert info.array_oid == psycopg.adapters.types["text"].array_oid
+
+
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+def test_fetch_not_found(conn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ conn.execute("select 1")
+
+ assert conn.info.transaction_status == status
+ info = TypeInfo.fetch(conn, name)
+ assert conn.info.transaction_status == status
+ assert info is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+async def test_fetch_not_found_async(aconn, name, status):
+ status = getattr(TransactionStatus, status)
+ if status == TransactionStatus.INTRANS:
+ await aconn.execute("select 1")
+
+ assert aconn.info.transaction_status == status
+ info = await TypeInfo.fetch(aconn, name)
+ assert aconn.info.transaction_status == status
+
+ assert info is None
+
+
+@pytest.mark.crdb_skip("composite")
+@pytest.mark.parametrize(
+ "name", ["testschema.testtype", sql.Identifier("testschema", "testtype")]
+)
+def test_fetch_by_schema_qualified_string(conn, name):
+ conn.execute("create schema if not exists testschema")
+ conn.execute("create type testschema.testtype as (foo text)")
+
+ info = TypeInfo.fetch(conn, name)
+ assert info.name == "testtype"
+ # assert info.schema == "testschema"
+ cur = conn.execute(
+ """
+ select oid, typarray from pg_type
+ where oid = 'testschema.testtype'::regtype
+ """
+ )
+ assert cur.fetchone() == (info.oid, info.array_oid)
+
+
+@pytest.mark.parametrize(
+ "name",
+ [
+ "text",
+ # TODO: support these?
+ # "pg_catalog.text",
+ # sql.Identifier("text"),
+ # sql.Identifier("pg_catalog", "text"),
+ ],
+)
+def test_registry_by_builtin_name(conn, name):
+ info = psycopg.adapters.types[name]
+ assert info.name == "text"
+ assert info.oid == 25
+
+
+def test_registry_empty():
+ r = psycopg.types.TypesRegistry()
+ assert r.get("text") is None
+ with pytest.raises(KeyError):
+ r["text"]
+
+
+@pytest.mark.parametrize("oid, aoid", [(1, 2), (1, 0), (0, 2), (0, 0)])
+def test_registry_invalid_oid(oid, aoid):
+ r = psycopg.types.TypesRegistry()
+ ti = psycopg.types.TypeInfo("test", oid, aoid)
+ r.add(ti)
+ assert r["test"] is ti
+ if oid:
+ assert r[oid] is ti
+ if aoid:
+ assert r[aoid] is ti
+ with pytest.raises(KeyError):
+ r[0]
+
+
+def test_registry_copy():
+ r = psycopg.types.TypesRegistry(psycopg.postgres.types)
+ assert r.get("text") is r["text"] is r[25]
+ assert r["text"].oid == 25
+
+
+def test_registry_isolated():
+ orig = psycopg.postgres.types
+ tinfo = orig["text"]
+ r = psycopg.types.TypesRegistry(orig)
+ tdummy = psycopg.types.TypeInfo("dummy", tinfo.oid, tinfo.array_oid)
+ r.add(tdummy)
+ assert r[25] is r["dummy"] is tdummy
+ assert orig[25] is r["text"] is tinfo
diff --git a/tests/test_typing.py b/tests/test_typing.py
new file mode 100644
index 0000000..fff9cec
--- /dev/null
+++ b/tests/test_typing.py
@@ -0,0 +1,449 @@
+import os
+
+import pytest
+
+HERE = os.path.dirname(os.path.abspath(__file__))
+
+
+@pytest.mark.parametrize(
+ "filename",
+ ["adapters_example.py", "typing_example.py"],
+)
+def test_typing_example(mypy, filename):
+ cp = mypy.run_on_file(os.path.join(HERE, filename))
+ errors = cp.stdout.decode("utf8", "replace").splitlines()
+ assert not errors
+ assert cp.returncode == 0
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.connect()",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.tuple_row)",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Connection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.namedtuple_row)",
+ "psycopg.Connection[NamedTuple]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.class_row(Thing))",
+ "psycopg.Connection[Thing]",
+ ),
+ (
+ "psycopg.connect(row_factory=thing_row)",
+ "psycopg.Connection[Thing]",
+ ),
+ (
+ "psycopg.Connection.connect()",
+ "psycopg.Connection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.Connection.connect(row_factory=rows.dict_row)",
+ "psycopg.Connection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "conn, curs, type",
+ [
+ (
+ "psycopg.connect()",
+ "conn.cursor()",
+ "psycopg.Cursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor()",
+ "psycopg.Cursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor(row_factory=rows.namedtuple_row)",
+ "psycopg.Cursor[NamedTuple]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.class_row(Thing))",
+ "conn.cursor()",
+ "psycopg.Cursor[Thing]",
+ ),
+ (
+ "psycopg.connect(row_factory=thing_row)",
+ "conn.cursor()",
+ "psycopg.Cursor[Thing]",
+ ),
+ (
+ "psycopg.connect()",
+ "conn.cursor(row_factory=thing_row)",
+ "psycopg.Cursor[Thing]",
+ ),
+ # Async cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor()",
+ "psycopg.AsyncCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(row_factory=thing_row)",
+ "psycopg.AsyncCursor[Thing]",
+ ),
+ # Server-side cursors
+ (
+ "psycopg.connect()",
+ "conn.cursor(name='foo')",
+ "psycopg.ServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "conn.cursor(name='foo')",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect()",
+ "conn.cursor(name='foo', row_factory=rows.dict_row)",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ # Async server-side cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(name='foo')",
+ "psycopg.AsyncServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "conn.cursor(name='foo')",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "conn.cursor(name='foo', row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_cursor_type(conn, curs, type, mypy):
+ stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "conn, curs, type",
+ [
+ (
+ "psycopg.connect()",
+ "psycopg.Cursor(conn)",
+ "psycopg.Cursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Cursor(conn)",
+ "psycopg.Cursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)",
+ "psycopg.Cursor[NamedTuple]",
+ ),
+ # Async cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncCursor(conn)",
+ "psycopg.AsyncCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncCursor(conn)",
+ "psycopg.AsyncCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncCursor(conn, row_factory=thing_row)",
+ "psycopg.AsyncCursor[Thing]",
+ ),
+ # Server-side cursors
+ (
+ "psycopg.connect()",
+ "psycopg.ServerCursor(conn, 'foo')",
+ "psycopg.ServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.ServerCursor(conn, name='foo')",
+ "psycopg.ServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.connect(row_factory=rows.dict_row)",
+ "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)",
+ "psycopg.ServerCursor[NamedTuple]",
+ ),
+ # Async server-side cursors
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncServerCursor(conn, name='foo')",
+ "psycopg.AsyncServerCursor[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor(conn, name='foo')",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.AsyncConnection.connect()",
+ "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)",
+ "psycopg.AsyncServerCursor[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_cursor_type_init(conn, curs, type, mypy):
+ stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "Optional[Tuple[Any, ...]]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "Optional[Dict[str, Any]]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "Optional[Thing]",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchone_type(conn_class, server_side, curs, type, mypy):
+ await_ = "await" if "Async" in conn_class else ""
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_} psycopg.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.fetchone()
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "Tuple[Any, ...]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "Dict[str, Any]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "Thing",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_iter_type(conn_class, server_side, curs, type, mypy):
+ if "Async" in conn_class:
+ async_ = "async "
+ await_ = "await "
+ else:
+ async_ = await_ = ""
+
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_}psycopg.{conn_class}.connect()
+curs = {curs}
+{async_}for obj in curs:
+ pass
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
+@pytest.mark.parametrize(
+ "curs, type",
+ [
+ (
+ "conn.cursor()",
+ "List[Tuple[Any, ...]]",
+ ),
+ (
+ "conn.cursor(row_factory=rows.dict_row)",
+ "List[Dict[str, Any]]",
+ ),
+ (
+ "conn.cursor(row_factory=thing_row)",
+ "List[Thing]",
+ ),
+ ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
+ await_ = "await" if "Async" in conn_class else ""
+ if server_side:
+ curs = curs.replace("(", "(name='foo',", 1)
+ stmts = f"""\
+conn = {await_} psycopg.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.{method}()
+"""
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_cur_subclass_execute(mypy, conn_class, server_side):
+ async_ = "async " if "Async" in conn_class else ""
+ await_ = "await" if "Async" in conn_class else ""
+ cur_base_class = "".join(
+ [
+ "Async" if "Async" in conn_class else "",
+ "Server" if server_side else "",
+ "Cursor",
+ ]
+ )
+ cur_name = "'foo'" if server_side else ""
+
+ src = f"""\
+from typing import Any, cast
+import psycopg
+from psycopg.rows import Row, TupleRow
+
+class MyCursor(psycopg.{cur_base_class}[Row]):
+ pass
+
+{async_}def test() -> None:
+ conn = {await_} psycopg.{conn_class}.connect()
+
+ cur: MyCursor[TupleRow]
+ reveal_type(cur)
+
+ cur = cast(MyCursor[TupleRow], conn.cursor({cur_name}))
+ {async_}with cur as cur2:
+ reveal_type(cur2)
+ cur3 = {await_} cur2.execute("")
+ reveal_type(cur3)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 3
+ types = [mypy.get_revealed(line) for line in out]
+ assert types[0] == types[1]
+ assert types[0] == types[2]
+
+
+def _test_reveal(stmts, type, mypy):
+ ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]"
+ stmts = "\n".join(f" {line}" for line in stmts.splitlines())
+
+ src = f"""\
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+from typing import Tuple, Union
+import psycopg
+from psycopg import rows
+
+class Thing:
+ def __init__(self, **kwargs: Any) -> None:
+ self.kwargs = kwargs
+
+def thing_row(
+ cur: Union[psycopg.Cursor[Any], psycopg.AsyncCursor[Any]],
+) -> Callable[[Sequence[Any]], Thing]:
+ assert cur.description
+ names = [d.name for d in cur.description]
+
+ def make_row(t: Sequence[Any]) -> Thing:
+ return Thing(**dict(zip(names, t)))
+
+ return make_row
+
+async def tmp() -> None:
+{stmts}
+ reveal_type(obj)
+
+ref: {type} = None {ignore}
+reveal_type(ref)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 2, "\n".join(out)
+ got, want = [mypy.get_revealed(line) for line in out]
+ assert got == want
+
+
+@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308")
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "MyConnection.connect()",
+ "MyConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "MyConnection.connect(row_factory=rows.tuple_row)",
+ "MyConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "MyConnection.connect(row_factory=rows.dict_row)",
+ "MyConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_generic_connect(conn, type, mypy):
+ src = f"""
+from typing import Any, Dict, Tuple
+import psycopg
+from psycopg import rows
+
+class MyConnection(psycopg.Connection[rows.Row]):
+ pass
+
+obj = {conn}
+reveal_type(obj)
+
+ref: {type} = None # type: ignore[assignment]
+reveal_type(ref)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 2, "\n".join(out)
+ got, want = [mypy.get_revealed(line) for line in out]
+ assert got == want
diff --git a/tests/test_waiting.py b/tests/test_waiting.py
new file mode 100644
index 0000000..63237e8
--- /dev/null
+++ b/tests/test_waiting.py
@@ -0,0 +1,159 @@
+import select # noqa: used in pytest.mark.skipif
+import socket
+import sys
+
+import pytest
+
+import psycopg
+from psycopg import waiting
+from psycopg import generators
+from psycopg.pq import ConnStatus, ExecStatus
+
+skip_if_not_linux = pytest.mark.skipif(
+ not sys.platform.startswith("linux"), reason="non-Linux platform"
+)
+
+waitfns = [
+ "wait",
+ "wait_selector",
+ pytest.param(
+ "wait_select", marks=pytest.mark.skipif("not hasattr(select, 'select')")
+ ),
+ pytest.param(
+ "wait_epoll", marks=pytest.mark.skipif("not hasattr(select, 'epoll')")
+ ),
+ pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
+]
+
+timeouts = [pytest.param({}, id="blank")]
+timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait_conn(dsn, timeout):
+ gen = generators.connect(dsn)
+ conn = waiting.wait_conn(gen, **timeout)
+ assert conn.status == ConnStatus.OK
+
+
+def test_wait_conn_bad(dsn):
+ gen = generators.connect("dbname=nosuchdb")
+ with pytest.raises(psycopg.OperationalError):
+ waiting.wait_conn(gen)
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@skip_if_not_linux
+def test_wait_ready(waitfn, wait, ready):
+ waitfn = getattr(waiting, waitfn)
+
+ def gen():
+ r = yield wait
+ return r
+
+ with socket.socket() as s:
+ r = waitfn(gen(), s.fileno())
+ assert r & ready
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait(pgconn, waitfn, timeout):
+ waitfn = getattr(waiting, waitfn)
+
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ (res,) = waitfn(gen, pgconn.socket, **timeout)
+ assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_bad(pgconn, waitfn):
+ waitfn = getattr(waiting, waitfn)
+
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ waitfn(gen, pgconn.socket)
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+ "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
+)
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_large_fd(dsn, waitfn):
+ waitfn = getattr(waiting, waitfn)
+
+ files = []
+ try:
+ try:
+ for i in range(1100):
+ files.append(open(__file__))
+ except OSError:
+ pytest.skip("can't open the number of files needed for the test")
+
+ pgconn = psycopg.pq.PGconn.connect(dsn.encode())
+ try:
+ assert pgconn.socket > 1024
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ if waitfn is waiting.wait_select:
+ with pytest.raises(ValueError):
+ waitfn(gen, pgconn.socket)
+ else:
+ (res,) = waitfn(gen, pgconn.socket)
+ assert res.status == ExecStatus.TUPLES_OK
+ finally:
+ pgconn.finish()
+ finally:
+ for f in files:
+ f.close()
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+@pytest.mark.asyncio
+async def test_wait_conn_async(dsn, timeout):
+ gen = generators.connect(dsn)
+ conn = await waiting.wait_conn_async(gen, **timeout)
+ assert conn.status == ConnStatus.OK
+
+
+@pytest.mark.asyncio
+async def test_wait_conn_async_bad(dsn):
+ gen = generators.connect("dbname=nosuchdb")
+ with pytest.raises(psycopg.OperationalError):
+ await waiting.wait_conn_async(gen)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@skip_if_not_linux
+async def test_wait_ready_async(wait, ready):
+ def gen():
+ r = yield wait
+ return r
+
+ with socket.socket() as s:
+ r = await waiting.wait_async(gen(), s.fileno())
+ assert r & ready
+
+
+@pytest.mark.asyncio
+async def test_wait_async(pgconn):
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ (res,) = await waiting.wait_async(gen, pgconn.socket)
+ assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.asyncio
+async def test_wait_async_bad(pgconn):
+ pgconn.send_query(b"select 1")
+ gen = generators.execute(pgconn)
+ socket = pgconn.socket
+ pgconn.finish()
+ with pytest.raises(psycopg.OperationalError):
+ await waiting.wait_async(gen, socket)
diff --git a/tests/test_windows.py b/tests/test_windows.py
new file mode 100644
index 0000000..09e61ba
--- /dev/null
+++ b/tests/test_windows.py
@@ -0,0 +1,23 @@
+import pytest
+import asyncio
+import sys
+
+from psycopg.errors import InterfaceError
+
+
+@pytest.mark.skipif(sys.platform != "win32", reason="windows only test")
+def test_windows_error(aconn_cls, dsn):
+ loop = asyncio.ProactorEventLoop() # type: ignore[attr-defined]
+
+ async def go():
+ with pytest.raises(
+ InterfaceError,
+ match="Psycopg cannot use the 'ProactorEventLoop'",
+ ):
+ await aconn_cls.connect(dsn)
+
+ try:
+ loop.run_until_complete(go())
+ finally:
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ loop.close()
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/types/__init__.py
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
new file mode 100644
index 0000000..74c17a6
--- /dev/null
+++ b/tests/types/test_array.py
@@ -0,0 +1,338 @@
+from typing import List, Any
+from decimal import Decimal
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat, Transformer, Dumper
+from psycopg.types import TypeInfo
+from psycopg._compat import prod
+from psycopg.postgres import types as builtins
+
+
+tests_str = [
+ ([[[[[["a"]]]]]], "{{{{{{a}}}}}}"),
+ ([[[[[[None]]]]]], "{{{{{{NULL}}}}}}"),
+ ([[[[[["NULL"]]]]]], '{{{{{{"NULL"}}}}}}'),
+ (["foo", "bar", "baz"], "{foo,bar,baz}"),
+ (["foo", None, "baz"], "{foo,null,baz}"),
+ (["foo", "null", "", "baz"], '{foo,"null","",baz}'),
+ (
+ [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]],
+ "{{foo,bar},{baz,qux},{quux,quuux}}",
+ ),
+ (
+ [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]],
+ r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}',
+ ),
+]
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("type", ["text", "int4"])
+def test_dump_empty_list(conn, fmt_in, type):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}::{type}[] = %s::{type}[]", ([], "{}"))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("obj, want", tests_str)
+def test_dump_list_str(conn, obj, want, fmt_in):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}::text[] = %s::text[]", (obj, want))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_empty_list_str(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::text[]", ([],))
+ assert cur.fetchone()[0] == []
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("want, obj", tests_str)
+def test_load_list_str(conn, obj, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::text[]", (obj,))
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_all_chars(conn, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ c = chr(i)
+ cur.execute(f"select %{fmt_in.value}::text[]", ([c],))
+ assert cur.fetchone()[0] == [c]
+
+ a = list(map(chr, range(1, 256)))
+ a.append("\u20ac")
+ cur.execute(f"select %{fmt_in.value}::text[]", (a,))
+ assert cur.fetchone()[0] == a
+
+ s = "".join(a)
+ cur.execute(f"select %{fmt_in.value}::text[]", ([s],))
+ assert cur.fetchone()[0] == [s]
+
+
+tests_int = [
+ ([10, 20, -30], "{10,20,-30}"),
+ ([10, None, 30], "{10,null,30}"),
+ ([[10, 20], [30, 40]], "{{10,20},{30,40}}"),
+]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("obj, want", tests_int)
+def test_dump_list_int(conn, obj, want):
+ cur = conn.cursor()
+ cur.execute("select %s::int[] = %s::int[]", (obj, want))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ [["a"], ["b", "c"]],
+ [["a"], []],
+ [[["a"]], ["b"]],
+ # [["a"], [["b"]]], # todo, but expensive (an isinstance per item)
+ # [True, b"a"], # TODO expensive too
+ ],
+)
+def test_bad_binary_array(input):
+ tx = Transformer()
+ with pytest.raises(psycopg.DataError):
+ tx.get_dumper(input, PyFormat.BINARY).dump(input)
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("want, obj", tests_int)
+def test_load_list_int(conn, obj, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::int[]", (obj,))
+ assert cur.fetchone()[0] == want
+
+ stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format(
+ obj, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["int4[]"])
+ (got,) = copy.read_row()
+
+ assert got == want
+
+
+@pytest.mark.crdb_skip("composite")
+def test_array_register(conn):
+ conn.execute("create table mytype (data text)")
+ cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+ res = cur.fetchone()
+ assert res[0] == "(foo)"
+ assert res[1] == "{(foo)}"
+
+ info = TypeInfo.fetch(conn, "mytype")
+ info.register(conn)
+
+ cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""")
+ res = cur.fetchone()
+ assert res[0] == "(foo)"
+ assert res[1] == ["(foo)"]
+
+
+@pytest.mark.crdb("skip", reason="aclitem")
+def test_array_of_unknown_builtin(conn):
+ user = conn.execute("select user").fetchone()[0]
+ # we cannot load this type, but we understand it is an array
+ val = f"{user}=arwdDxt/{user}"
+ cur = conn.execute(f"select '{val}'::aclitem, array['{val}']::aclitem[]")
+ res = cur.fetchone()
+ assert cur.description[0].type_code == builtins["aclitem"].oid
+ assert res[0] == val
+ assert cur.description[1].type_code == builtins["aclitem"].array_oid
+ assert res[1] == [val]
+
+
+@pytest.mark.parametrize(
+ "num, type",
+ [
+ (0, "int2"),
+ (2**15 - 1, "int2"),
+ (-(2**15), "int2"),
+ (2**15, "int4"),
+ (2**31 - 1, "int4"),
+ (-(2**31), "int4"),
+ (2**31, "int8"),
+ (2**63 - 1, "int8"),
+ (-(2**63), "int8"),
+ (2**63, "numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_numbers_array(num, type, fmt_in):
+ for array in ([num], [1, num]):
+ tx = Transformer()
+ dumper = tx.get_dumper(array, fmt_in)
+ dumper.dump(array)
+ assert dumper.oid == builtins[type].array_oid
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ if wrapper is Decimal:
+ want_cls = Decimal
+ else:
+ assert wrapper.__mro__[1] in (int, float)
+ want_cls = wrapper.__mro__[1]
+
+ obj = [wrapper(1), wrapper(0), wrapper(-1), None]
+ cur = conn.cursor(binary=fmt_out)
+ got = cur.execute(f"select %{fmt_in.value}", [obj]).fetchone()[0]
+ assert got == obj
+ for i in got:
+ if i is not None:
+ assert type(i) is want_cls
+
+
+def test_mix_types(conn):
+ with pytest.raises(psycopg.DataError):
+ conn.execute("select %s", ([1, 0.5],))
+
+ with pytest.raises(psycopg.DataError):
+ conn.execute("select %s", ([1, Decimal("0.5")],))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list_mix(conn, fmt_in):
+ objs = list(range(3))
+ conn.execute("create table testarrays (col1 bigint[], col2 bigint[])")
+ # pro tip: don't get confused with the types
+ f1, f2 = conn.execute(
+ f"insert into testarrays values (%{fmt_in.value}, %{fmt_in.value}) returning *",
+ (objs, []),
+ ).fetchone()
+ assert f1 == objs
+ assert f2 == []
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table test (id serial primary key, data date[])")
+ with conn.transaction():
+ cur.execute(
+ f"insert into test (data) values (%{fmt_in.value}) returning id", ([],)
+ )
+ id = cur.fetchone()[0]
+ cur.execute("select data from test")
+ assert cur.fetchone() == ([],)
+
+ # test untyped list in a filter
+ cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([id],))
+ assert cur.fetchone()
+ cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([],))
+ assert not cur.fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_empty_list_after_choice(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table test (id serial primary key, data float[])")
+ cur.executemany(
+ f"insert into test (data) values (%{fmt_in.value})", [([1.0],), ([],)]
+ )
+ cur.execute("select data from test order by id")
+ assert cur.fetchall() == [([1.0],), ([],)]
+
+
+@pytest.mark.crdb_skip("geometric types")
+def test_dump_list_no_comma_separator(conn):
+ class Box:
+ def __init__(self, x1, y1, x2, y2):
+ self.coords = (x1, y1, x2, y2)
+
+ class BoxDumper(Dumper):
+
+ format = pq.Format.TEXT
+ oid = psycopg.postgres.types["box"].oid
+
+ def dump(self, box):
+ return ("(%s,%s),(%s,%s)" % box.coords).encode()
+
+ conn.adapters.register_dumper(Box, BoxDumper)
+
+ cur = conn.execute("select (%s::box)::text", (Box(1, 2, 3, 4),))
+ got = cur.fetchone()[0]
+ assert got == "(3,4),(1,2)"
+
+ cur = conn.execute(
+ "select (%s::box[])::text", ([Box(1, 2, 3, 4), Box(5, 4, 3, 2)],)
+ )
+ got = cur.fetchone()[0]
+ assert got == "{(3,4),(1,2);(5,4),(3,2)}"
+
+
+@pytest.mark.crdb_skip("geometric types")
+def test_load_array_no_comma_separator(conn):
+ cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]")
+ # Not parsed at the moment, but split ok on ; separator
+ assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"]
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_nested_array(conn, fmt_out):
+ dims = [3, 4, 5, 6]
+ a: List[Any] = list(range(prod(dims)))
+ for dim in dims[-1:0:-1]:
+ a = [a[i : i + dim] for i in range(0, len(a), dim)]
+
+ assert a[2][3][4][5] == prod(dims) - 1
+
+ sa = str(a).replace("[", "{").replace("]", "}")
+ got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0]
+ assert got == a
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "obj, want",
+ [
+ ("'[0:1]={a,b}'::text[]", ["a", "b"]),
+ ("'[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}'::int[]", [[[1, 2, 3], [4, 5, 6]]]),
+ ],
+)
+def test_array_with_bounds(conn, obj, want, fmt_out):
+ got = conn.execute(f"select {obj}", binary=fmt_out).fetchone()[0]
+ assert got == want
+
+
+@pytest.mark.crdb_skip("nested array")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_all_chars_with_bounds(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ c = chr(i)
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([c],))
+ assert cur.fetchone()[0] == ["a", "b", c]
+
+ a = list(map(chr, range(1, 256)))
+ a.append("\u20ac")
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", (a,))
+ assert cur.fetchone()[0] == ["a", "b"] + a
+
+ s = "".join(a)
+ cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([s],))
+ assert cur.fetchone()[0] == ["a", "b", s]
diff --git a/tests/types/test_bool.py b/tests/types/test_bool.py
new file mode 100644
index 0000000..edd4dad
--- /dev/null
+++ b/tests/types/test_bool.py
@@ -0,0 +1,47 @@
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+from psycopg.postgres import types as builtins
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("b", [True, False])
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ result = cur.execute(f"select %{fmt_in.value}", (b,)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].oid
+ assert result is b
+
+ result = cur.execute(f"select %{fmt_in.value}", ([b],)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].array_oid
+ assert result[0] is b
+
+
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode(
+ "ascii"
+ )
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+ assert cur.fetchone()[0] is val
+
+
+def test_quote_none(conn):
+
+ tx = Transformer()
+ assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL"
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+ assert cur.fetchone()[0] is None
diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py
new file mode 100644
index 0000000..47beecf
--- /dev/null
+++ b/tests/types/test_composite.py
@@ -0,0 +1,396 @@
+import pytest
+
+from psycopg import pq, postgres, sql
+from psycopg.adapt import PyFormat
+from psycopg.postgres import types as builtins
+from psycopg.types.range import Range
+from psycopg.types.composite import CompositeInfo, register_composite
+from psycopg.types.composite import TupleDumper, TupleBinaryDumper
+
+from ..utils import eur
+from ..fix_crdb import is_crdb, crdb_skip_message
+
+
+pytestmark = pytest.mark.crdb_skip("composite")
+
+tests_str = [
+ ("", ()),
+ # Funnily enough there's no way to represent (None,) in Postgres
+ ("null", ()),
+ ("null,null", (None, None)),
+ ("null, ''", (None, "")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ ("42", "foo", "ba,r", "ba'z", 'qu"x'),
+ ),
+ ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"')),
+]
+
+
+@pytest.mark.parametrize("rec, want", tests_str)
+def test_load_record(conn, want, rec):
+ cur = conn.cursor()
+ res = cur.execute(f"select row({rec})").fetchone()[0]
+ assert res == want
+
+
+@pytest.mark.parametrize("rec, obj", tests_str)
+def test_dump_tuple(conn, rec, obj):
+ cur = conn.cursor()
+ fields = [f"f{i} text" for i in range(len(obj))]
+ cur.execute(
+ f"""
+ drop type if exists tmptype;
+ create type tmptype as ({', '.join(fields)});
+ """
+ )
+ info = CompositeInfo.fetch(conn, "tmptype")
+ register_composite(info, conn)
+
+ res = conn.execute("select %s::tmptype", [obj]).fetchone()[0]
+ assert res == obj
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_all_chars(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
+ assert res == (chr(i),)
+
+ cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256)))
+ res = cur.fetchone()[0]
+ assert res == tuple(map(chr, range(1, 256)))
+
+ s = "".join(map(chr, range(1, 256)))
+ res = cur.execute("select row(%s::text)", [s]).fetchone()[0]
+ assert res == (s,)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_range(conn, fmt_in):
+ conn.execute(
+ """
+ drop type if exists tmptype;
+ create type tmptype as (num integer, range daterange, nums integer[])
+ """
+ )
+ info = CompositeInfo.fetch(conn, "tmptype")
+ register_composite(info, conn)
+
+ cur = conn.execute(
+ f"select pg_typeof(%{fmt_in.value})",
+ [info.python_type(10, Range(empty=True), [])],
+ )
+ assert cur.fetchone()[0] == "tmptype"
+
+
+@pytest.mark.parametrize(
+ "rec, want",
+ [
+ ("", ()),
+ ("null", (None,)), # Unlike text format, this is a thing
+ ("null,null", (None, None)),
+ ("null, ''", (None, b"")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ (42, b"foo", b"ba,r", b"ba'z", b'qu"x'),
+ ),
+ (
+ "'foo''', '''foo', '\"bar', 'bar\"' ",
+ (b"foo'", b"'foo", b'"bar', b'bar"'),
+ ),
+ (
+ "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ",
+ (10, None, 20.0, None, "foo", b"bar"),
+ ),
+ ],
+)
+def test_load_record_binary(conn, want, rec):
+ cur = conn.cursor(binary=True)
+ res = cur.execute(f"select row({rec})").fetchone()[0]
+ assert res == want
+ for o1, o2 in zip(res, want):
+ assert type(o1) is type(o2)
+
+
+@pytest.fixture(scope="session")
+def testcomp(svcconn):
+ if is_crdb(svcconn):
+ pytest.skip(crdb_skip_message("composite"))
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ create schema if not exists testschema;
+
+ drop type if exists testcomp cascade;
+ drop type if exists testschema.testcomp cascade;
+
+ create type testcomp as (foo text, bar int8, baz float8);
+ create type testschema.testcomp as (foo text, bar int8, qux bool);
+ """
+ )
+ return CompositeInfo.fetch(svcconn, "testcomp")
+
+
+fetch_cases = [
+ (
+ "testcomp",
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ "testschema.testcomp",
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ (
+ sql.Identifier("testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ sql.Identifier("testschema", "testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+]
+
+
+@pytest.mark.parametrize("name, fields", fetch_cases)
+def test_fetch_info(conn, testcomp, name, fields):
+ info = CompositeInfo.fetch(conn, name)
+ assert info.name == "testcomp"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.field_names) == 3
+ assert len(info.field_types) == 3
+ for i, (name, t) in enumerate(fields):
+ assert info.field_names[i] == name
+ assert info.field_types[i] == builtins[t].oid
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, fields", fetch_cases)
+async def test_fetch_info_async(aconn, testcomp, name, fields):
+ info = await CompositeInfo.fetch(aconn, name)
+ assert info.name == "testcomp"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.field_names) == 3
+ assert len(info.field_types) == 3
+ for i, (name, t) in enumerate(fields):
+ assert info.field_names[i] == name
+ assert info.field_types[i] == builtins[t].oid
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_tuple_all_chars(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ for i in range(1, 256):
+ (res,) = cur.execute(
+ f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}::testcomp",
+ (i, (chr(i), 1, 1.0)),
+ ).fetchone()
+ assert res is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_composite_all_chars(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ register_composite(testcomp, cur)
+ factory = testcomp.python_type
+ for i in range(1, 256):
+ obj = factory(chr(i), 1, 1.0)
+ (res,) = cur.execute(
+ f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}", (i, obj)
+ ).fetchone()
+ assert res is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_composite_null(conn, fmt_in, testcomp):
+ cur = conn.cursor()
+ register_composite(testcomp, cur)
+ factory = testcomp.python_type
+
+ obj = factory("foo", 1, None)
+ rec = cur.execute(
+ f"""
+ select row('foo', 1, NULL)::testcomp = %(obj){fmt_in.value},
+ %(obj){fmt_in.value}::text
+ """,
+ {"obj": obj},
+ ).fetchone()
+ assert rec[0] is True, rec[1]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_composite(conn, testcomp, fmt_out):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert res.foo == "hello"
+ assert res.bar == 10
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_composite_factory(conn, testcomp, fmt_out):
+ info = CompositeInfo.fetch(conn, "testcomp")
+
+ class MyThing:
+ def __init__(self, *args):
+ self.foo, self.bar, self.baz = args
+
+ register_composite(info, conn, factory=MyThing)
+ assert info.python_type is MyThing
+
+ cur = conn.cursor(binary=fmt_out)
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert isinstance(res, MyThing)
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)
+
+
+def test_register_scope(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert postgres.adapters._loaders[fmt].pop(oid)
+
+ for f in PyFormat:
+ assert postgres.adapters._dumpers[f].pop(info.python_type)
+
+ cur = conn.cursor()
+ register_composite(info, cur)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert oid not in postgres.adapters._loaders[fmt]
+ assert oid not in conn.adapters._loaders[fmt]
+ assert oid in cur.adapters._loaders[fmt]
+
+ register_composite(info, conn)
+ for fmt in pq.Format:
+ for oid in (info.oid, info.array_oid):
+ assert oid not in postgres.adapters._loaders[fmt]
+ assert oid in conn.adapters._loaders[fmt]
+
+
+def test_type_dumper_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+ assert issubclass(info.python_type, tuple)
+ assert info.python_type.__name__ == "testcomp"
+ d = conn.adapters.get_dumper(info.python_type, "s")
+ assert issubclass(d, TupleDumper)
+ assert d is not TupleDumper
+
+ tc = info.python_type("foo", 42, 3.14)
+ cur = conn.execute("select pg_typeof(%s)", [tc])
+ assert cur.fetchone()[0] == "testcomp"
+
+
+def test_type_dumper_registered_binary(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+ register_composite(info, conn)
+ assert issubclass(info.python_type, tuple)
+ assert info.python_type.__name__ == "testcomp"
+ d = conn.adapters.get_dumper(info.python_type, "b")
+ assert issubclass(d, TupleBinaryDumper)
+ assert d is not TupleBinaryDumper
+
+ tc = info.python_type("foo", 42, 3.14)
+ cur = conn.execute("select pg_typeof(%b)", [tc])
+ assert cur.fetchone()[0] == "testcomp"
+
+
+def test_callable_dumper_not_registered(conn, testcomp):
+ info = CompositeInfo.fetch(conn, "testcomp")
+
+ def fac(*args):
+ return args + (args[-1],)
+
+ register_composite(info, conn, factory=fac)
+ assert info.python_type is None
+
+ # but the loader is registered
+ cur = conn.execute("select '(foo,42,3.14)'::testcomp")
+ assert cur.fetchone()[0] == ("foo", 42, 3.14, 3.14)
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="composite"):
+ register_composite(None, conn) # type: ignore[arg-type]
+
+
+def test_invalid_fields_names(conn):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ f"""
+ create type "a-b" as ("c-d" text, "{eur}" int);
+ create type "-x-{eur}" as ("w-ww" "a-b", "0" int);
+ """
+ )
+ ab = CompositeInfo.fetch(conn, '"a-b"')
+ x = CompositeInfo.fetch(conn, f'"-x-{eur}"')
+ register_composite(ab, conn)
+ register_composite(x, conn)
+ obj = x.python_type(ab.python_type("foo", 10), 20)
+ conn.execute(f"""create table meh (wat "-x-{eur}")""")
+ conn.execute("insert into meh values (%s)", [obj])
+ got = conn.execute("select wat from meh").fetchone()[0]
+ assert obj == got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "1", "'"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ sql.SQL("create type {name} as (foo text)").format(name=sql.Identifier(name))
+ )
+ info = CompositeInfo.fetch(conn, sql.Identifier(name).as_string(conn))
+ register_composite(info, conn)
+ obj = info.python_type("hello")
+ assert sql.Literal(obj).as_string(conn) == f"'(hello)'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ got = cur.fetchone()[0]
+ assert got == obj
+ assert type(got) is type(obj)
+
+
+@pytest.mark.parametrize(
+ "name, attr",
+ [
+ ("a-b", "a_b"),
+ (f"{eur}", "f_"),
+ ("üåäö", "üåäö"),
+ ("order", "order"),
+ ("1", "f1"),
+ ],
+)
+def test_literal_invalid_attr(conn, name, attr):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(
+ sql.SQL("create type test_attr as ({name} text)").format(
+ name=sql.Identifier(name)
+ )
+ )
+ info = CompositeInfo.fetch(conn, "test_attr")
+ register_composite(info, conn)
+ obj = info.python_type("hello")
+ assert getattr(obj, attr) == "hello"
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ got = cur.fetchone()[0]
+ assert got == obj
+ assert type(got) is type(obj)
diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py
new file mode 100644
index 0000000..11fe493
--- /dev/null
+++ b/tests/types/test_datetime.py
@@ -0,0 +1,813 @@
+import datetime as dt
+
+import pytest
+
+from psycopg import DataError, pq, sql
+from psycopg.adapt import PyFormat
+
+crdb_skip_datestyle = pytest.mark.crdb("skip", reason="set datestyle/intervalstyle")
+crdb_skip_negative_interval = pytest.mark.crdb("skip", reason="negative interval")
+crdb_skip_invalid_tz = pytest.mark.crdb(
+ "skip", reason="crdb doesn't allow invalid timezones"
+)
+
+datestyles_in = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["DMY", "MDY", "YMD"]
+]
+datestyles_out = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["ISO", "Postgres", "SQL", "German"]
+]
+
+intervalstyles = [
+ pytest.param(datestyle, marks=crdb_skip_datestyle)
+ for datestyle in ["sql_standard", "postgres", "postgres_verbose", "iso_8601"]
+]
+
+
+class TestDate:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_date(self, conn, val, expr, fmt_in):
+ val = as_date(val)
+ cur = conn.cursor()
+ cur.execute(f"select '{expr}'::date = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(
+ sql.SQL("select {}::date = {}").format(
+ sql.Literal(val), sql.Placeholder(format=fmt_in)
+ ),
+ (val,),
+ )
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_date_datestyle(self, conn, datestyle_in):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO,{datestyle_in}")
+ cur.execute("select 'epoch'::date + 1 = %t", (dt.date(1970, 1, 2),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_date(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::date")
+ assert cur.fetchone()[0] == as_date(val)
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_date_datestyle(self, conn, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select '2000-01-02'::date")
+ assert cur.fetchone()[0] == dt.date(2000, 1, 2)
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_date_overflow(self, conn, val, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1))
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_date_overflow_binary(self, conn, val):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1))
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ overflow_samples = [
+ ("-infinity", "date too small"),
+ ("1000-01-01 BC", "date too small"),
+ ("10000-01-01", "date too large"),
+ ("infinity", "date too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_load_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::date", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_load_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::date", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ def test_infinity_date_example(self, conn):
+ # NOTE: this is an example in the docs. Make sure it doesn't regress when
+ # adding binary datetime adapters
+ from datetime import date
+ from psycopg.types.datetime import DateLoader, DateDumper
+
+ class InfDateDumper(DateDumper):
+ def dump(self, obj):
+ if obj == date.max:
+ return b"infinity"
+ else:
+ return super().dump(obj)
+
+ class InfDateLoader(DateLoader):
+ def load(self, data):
+ if data == b"infinity":
+ return date.max
+ else:
+ return super().load(data)
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(date, InfDateDumper)
+ cur.adapters.register_loader("date", InfDateLoader)
+
+ rec = cur.execute(
+ "SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]
+ ).fetchone()
+ assert rec == ("2020-12-31", "infinity")
+ rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone()
+ assert rec == (date(2020, 12, 31), date(9999, 12, 31))
+
+
+class TestDatetime:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "0001-01-01 00:00"),
+ ("258,1,8,1,12,32,358261", "0258-1-8 1:12:32.358261"),
+ ("1000,1,1,0,0", "1000-01-01 00:00"),
+ ("2000,1,1,0,0", "2000-01-01 00:00"),
+ ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"),
+ ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"),
+ ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"),
+ ("2000,1,1,0,0,0,1", "2000-01-01 00:00:00.000001"),
+ ("2034,02,03,23,34,27,951357", "2034-02-03 23:34:27.951357"),
+ ("2200,1,1,0,0,0,1", "2200-01-01 00:00:00.000001"),
+ ("2300,1,1,0,0,0,1", "2300-01-01 00:00:00.000001"),
+ ("7000,1,1,0,0,0,1", "7000-01-01 00:00:00.000001"),
+ ("max", "9999-12-31 23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetime(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select %{fmt_in.value}", (as_dt(val),))
+ cur.execute(f"select '{expr}'::timestamp = %{fmt_in.value}", (as_dt(val),))
+ cur.execute(
+ f"""
+ select '{expr}'::timestamp = %(val){fmt_in.value},
+ '{expr}', %(val){fmt_in.value}::text
+ """,
+ {"val": as_dt(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (want, got)
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_datetime_datestyle(self, conn, datestyle_in):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO, {datestyle_in}")
+ cur.execute(
+ "select 'epoch'::timestamp + '1d 3h 4m 5s'::interval = %t",
+ (dt.datetime(1970, 1, 2, 3, 4, 5),),
+ )
+ assert cur.fetchone()[0] is True
+
+ load_datetime_samples = [
+ ("min", "0001-01-01"),
+ ("1000,1,1", "1000-01-01"),
+ ("2000,1,1", "2000-01-01"),
+ ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"),
+ ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"),
+ ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"),
+ ("2000,12,31", "2000-12-31"),
+ ("3000,1,1", "3000-01-01"),
+ ("max", "9999-12-31 23:59:59.999999"),
+ ]
+
+ @pytest.mark.parametrize("val, expr", load_datetime_samples)
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select '{expr}'::timestamp")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize("val, expr", load_datetime_samples)
+ def test_load_datetime_binary(self, conn, val, expr):
+ cur = conn.cursor(binary=True)
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select '{expr}'::timestamp")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ def test_load_datetime_overflow(self, conn, val, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute(
+ "select %t::timestamp + %s * '1s'::interval",
+ (as_dt(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_datetime_overflow_binary(self, conn, val):
+ cur = conn.cursor(binary=True)
+ cur.execute(
+ "select %t::timestamp + %s * '1s'::interval",
+ (as_dt(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ overflow_samples = [
+ ("-infinity", "timestamp too small"),
+ ("1000-01-01 12:00 BC", "timestamp too small"),
+ ("10000-01-01 12:00", "timestamp too large"),
+ ("infinity", "timestamp too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::timestamp", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::timestamp", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @crdb_skip_datestyle
+ def test_load_all_month_names(self, conn):
+ cur = conn.cursor(binary=False)
+ cur.execute("set datestyle = 'Postgres'")
+ for i in range(12):
+ d = dt.datetime(2000, i + 1, 15)
+ cur.execute("select %s", [d])
+ assert cur.fetchone()[0] == d
+
+
+class TestDateTimeTz:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min~-2", "0001-01-01 00:00-02:00"),
+ ("min~-12", "0001-01-01 00:00-12:00"),
+ (
+ "258,1,8,1,12,32,358261~1:2:3",
+ "0258-1-8 1:12:32.358261+01:02:03",
+ ),
+ ("1000,1,1,0,0~2", "1000-01-01 00:00+2"),
+ ("2000,1,1,0,0~2", "2000-01-01 00:00+2"),
+ ("2000,1,1,0,0~12", "2000-01-01 00:00+12"),
+ ("2000,1,1,0,0~-12", "2000-01-01 00:00-12"),
+ ("2000,1,1,0,0~01:02:03", "2000-01-01 00:00+01:02:03"),
+ ("2000,1,1,0,0~-01:02:03", "2000-01-01 00:00-01:02:03"),
+ ("2000,12,31,23,59,59,999999~2", "2000-12-31 23:59:59.999999+2"),
+ (
+ "2034,02,03,23,34,27,951357~-4:27",
+ "2034-02-03 23:34:27.951357-04:27",
+ ),
+ ("2300,1,1,0,0,0,1~1", "2300-01-01 00:00:00.000001+1"),
+ ("3000,1,1,0,0~2", "3000-01-01 00:00+2"),
+ ("7000,1,1,0,0,0,1~-1:2:3", "7000-01-01 00:00:00.000001-01:02:03"),
+ ("max~2", "9999-12-31 23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetimetz(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(
+ f"""
+ select '{expr}'::timestamptz = %(val){fmt_in.value},
+ '{expr}', %(val){fmt_in.value}::text
+ """,
+ {"val": as_dt(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (want, got)
+
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_dump_datetimetz_datestyle(self, conn, datestyle_in):
+ tzinfo = dt.timezone(dt.timedelta(hours=2))
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = ISO, {datestyle_in}")
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(
+ "select 'epoch'::timestamptz + '1d 3h 4m 5.678s'::interval = %t",
+ (dt.datetime(1970, 1, 2, 5, 4, 5, 678000, tzinfo=tzinfo),),
+ )
+ assert cur.fetchone()[0] is True
+
+ load_datetimetz_samples = [
+ ("2000,1,1~2", "2000-01-01", "-02:00"),
+ ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"),
+ ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"),
+ ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"),
+ ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"),
+ ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"),
+ ("2000,12,31~2", "2000-12-31", "-02:00"),
+ ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"),
+ ]
+
+ @crdb_skip_datestyle
+ @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
+ @pytest.mark.parametrize("datestyle_out", ["ISO"])
+ def test_load_datetimetz(self, conn, val, expr, timezone, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, DMY")
+ cur.execute(f"set timezone to '{timezone}'")
+ got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
+ assert got == as_dt(val)
+
+ @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
+ def test_load_datetimetz_binary(self, conn, val, expr, timezone):
+ cur = conn.cursor(binary=True)
+ cur.execute(f"set timezone to '{timezone}'")
+ got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
+ assert got == as_dt(val)
+
+ @pytest.mark.xfail # parse timezone names
+ @crdb_skip_datestyle
+ @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
+ @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
+ @pytest.mark.parametrize("datestyle_in", datestyles_in)
+ def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(f"select '{expr}'::timestamptz")
+ assert cur.fetchone()[0] == as_dt(val)
+
+ @pytest.mark.parametrize(
+ "tzname, expr, tzoff",
+ [
+ ("UTC", "2000-1-1", 0),
+ ("UTC", "2000-7-1", 0),
+ ("Europe/Rome", "2000-1-1", 3600),
+ ("Europe/Rome", "2000-7-1", 7200),
+ ("Europe/Rome", "1000-1-1", 2996),
+ pytest.param("NOSUCH0", "2000-1-1", 0, marks=crdb_skip_invalid_tz),
+ pytest.param("0", "2000-1-1", 0, marks=crdb_skip_invalid_tz),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_datetimetz_tz(self, conn, fmt_out, tzname, expr, tzoff):
+ conn.execute("select set_config('TimeZone', %s, true)", [tzname])
+ cur = conn.cursor(binary=fmt_out)
+ ts = cur.execute("select %s::timestamptz", [expr]).fetchone()[0]
+ assert ts.utcoffset().total_seconds() == tzoff
+
+ @pytest.mark.parametrize(
+ "val, type",
+ [
+ ("2000,1,2,3,4,5,6", "timestamp"),
+ ("2000,1,2,3,4,5,6~0", "timestamptz"),
+ ("2000,1,2,3,4,5,6~2", "timestamptz"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_datetime_tz_or_not_tz(self, conn, val, type, fmt_in):
+ val = as_dt(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value}
+ """,
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '2000-01-01 01:02:03.123456-10:20'::timestamptz,
+ '11111111'::int4
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["timestamptz", "int4"])
+ rec = copy.read_row()
+
+ tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+ want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz)
+ assert rec[0] == want
+ assert rec[1] == 11111111
+
+ overflow_samples = [
+ ("-infinity", "timestamp too small"),
+ ("1000-01-01 12:00+00 BC", "timestamp too small"),
+ ("10000-01-01 12:00+00", "timestamp too large"),
+ ("infinity", "timestamp too large"),
+ ]
+
+ @pytest.mark.parametrize("datestyle_out", datestyles_out)
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message(self, conn, datestyle_out, val, msg):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, YMD")
+ cur.execute("select %s::timestamptz", (val,))
+ if datestyle_out == "ISO":
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+ else:
+ with pytest.raises(NotImplementedError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize("val, msg", overflow_samples)
+ def test_overflow_message_binary(self, conn, val, msg):
+ cur = conn.cursor(binary=True)
+ cur.execute("select %s::timestamptz", (val,))
+ with pytest.raises(DataError) as excinfo:
+ cur.fetchone()[0]
+ assert msg in str(excinfo.value)
+
+ @pytest.mark.parametrize(
+ "valname, tzval, tzname",
+ [
+ ("max", "-06", "America/Chicago"),
+ ("min", "+09:18:59", "Asia/Tokyo"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_max_with_timezone(self, conn, fmt_out, valname, tzval, tzname):
+ # This happens e.g. in Django when it caches forever.
+ # e.g. see Django test cache.tests.DBCacheTests.test_forever_timeout
+ val = getattr(dt.datetime, valname).replace(microsecond=0)
+ tz = dt.timezone(as_tzoffset(tzval))
+ want = val.replace(tzinfo=tz)
+
+ conn.execute("set timezone to '%s'" % tzname)
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::timestamptz", [str(val) + tzval])
+ got = cur.fetchone()[0]
+
+ assert got == want
+
+ extra = "1 day" if valname == "max" else "-1 day"
+ with pytest.raises(DataError):
+ cur.execute(
+ "select %s::timestamptz + %s::interval",
+ [str(val) + tzval, extra],
+ )
+ got = cur.fetchone()[0]
+
+
+class TestTime:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "00:00"),
+ ("10,20,30,40", "10:20:30.000040"),
+ ("max", "23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_time(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select '{expr}'::time = %(val){fmt_in.value},
+ '{expr}'::time::text, %(val){fmt_in.value}::text
+ """,
+ {"val": as_time(val)},
+ )
+ ok, want, got = cur.fetchone()
+ assert ok, (got, want)
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "00:00"),
+ ("1,2", "01:02"),
+ ("10,20", "10:20"),
+ ("10,20,30", "10:20:30"),
+ ("10,20,30,40", "10:20:30.000040"),
+ ("max", "23:59:59.999999"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_time(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::time")
+ assert cur.fetchone()[0] == as_time(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_time_24(self, conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select '24:00'::time")
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+
+class TestTimeTz:
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min~-10", "00:00-10:00"),
+ ("min~+12", "00:00+12:00"),
+ ("10,20,30,40~-2", "10:20:30.000040-02:00"),
+ ("10,20,30,40~0", "10:20:30.000040Z"),
+ ("10,20,30,40~+2:30", "10:20:30.000040+02:30"),
+ ("max~-12", "23:59:59.999999-12:00"),
+ ("max~+12", "23:59:59.999999+12:00"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_timetz(self, conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute("set timezone to '-02:00'")
+ cur.execute(f"select '{expr}'::timetz = %{fmt_in.value}", (as_time(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr, timezone",
+ [
+ ("0,0~-12", "00:00", "12:00"),
+ ("0,0~12", "00:00", "-12:00"),
+ ("3,4,5,6~2", "03:04:05.000006", "-02:00"),
+ ("3,4,5,6~7:8", "03:04:05.000006", "-07:08"),
+ ("3,0,0,456789~2", "03:00:00.456789", "-02:00"),
+ ("3,0,0,456789~-2", "03:00:00.456789", "+02:00"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_timetz(self, conn, val, timezone, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"set timezone to '{timezone}'")
+ cur.execute(f"select '{expr}'::timetz")
+ assert cur.fetchone()[0] == as_time(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_timetz_24(self, conn, fmt_out):
+ cur = conn.cursor()
+ cur.execute("select '24:00'::timetz")
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.parametrize(
+ "val, type",
+ [
+ ("3,4,5,6", "time"),
+ ("3,4,5,6~0", "timetz"),
+ ("3,4,5,6~2", "timetz"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_in", PyFormat)
+ def test_dump_time_tz_or_not_tz(self, conn, val, type, fmt_in):
+ val = as_time(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"""
+ select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value}
+ """,
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '01:02:03.123456-10:20'::timetz,
+ '11111111'::int4
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["timetz", "int4"])
+ rec = copy.read_row()
+
+ tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+ want = dt.time(1, 2, 3, 123456, tzinfo=tz)
+ assert rec[0] == want
+ assert rec[1] == 11111111
+
+
+class TestInterval:
+ dump_timedelta_samples = [
+ ("min", "-999999999 days"),
+ ("1d", "1 day"),
+ pytest.param("-1d", "-1 day", marks=crdb_skip_negative_interval),
+ ("1s", "1 s"),
+ pytest.param("-1s", "-1 s", marks=crdb_skip_negative_interval),
+ pytest.param("-1m", "-0.000001 s", marks=crdb_skip_negative_interval),
+ ("1m", "0.000001 s"),
+ ("max", "999999999 days 23:59:59.999999"),
+ ]
+
+ @pytest.mark.parametrize("val, expr", dump_timedelta_samples)
+ @pytest.mark.parametrize("intervalstyle", intervalstyles)
+ def test_dump_interval(self, conn, val, expr, intervalstyle):
+ cur = conn.cursor()
+ cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+ cur.execute(f"select '{expr}'::interval = %t", (as_td(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize("val, expr", dump_timedelta_samples)
+ def test_dump_interval_binary(self, conn, val, expr):
+ cur = conn.cursor()
+ cur.execute(f"select '{expr}'::interval = %b", (as_td(val),))
+ assert cur.fetchone()[0] is True
+
+ @pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("1s", "1 sec"),
+ ("-1s", "-1 sec"),
+ ("60s", "1 min"),
+ ("3600s", "1 hour"),
+ ("1s,1000m", "1.001 sec"),
+ ("1s,1m", "1.000001 sec"),
+ ("1d", "1 day"),
+ ("-10d", "-10 day"),
+ ("1d,1s,1m", "1 day 1.000001 sec"),
+ ("-86399s,-999999m", "-23:59:59.999999"),
+ ("-3723s,-400000m", "-1:2:3.4"),
+ ("3723s,400000m", "1:2:3.4"),
+ ("86399s,999999m", "23:59:59.999999"),
+ ("30d", "30 day"),
+ ("365d", "1 year"),
+ ("-365d", "-1 year"),
+ ("-730d", "-2 years"),
+ ("1460d", "4 year"),
+ ("30d", "1 month"),
+ ("-30d", "-1 month"),
+ ("60d", "2 month"),
+ ("-90d", "-3 month"),
+ ],
+ )
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_interval(self, conn, val, expr, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select '{expr}'::interval")
+ assert cur.fetchone()[0] == as_td(val)
+
+ @crdb_skip_datestyle
+ @pytest.mark.xfail # weird interval outputs
+ @pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")])
+ @pytest.mark.parametrize(
+ "intervalstyle",
+ ["sql_standard", "postgres_verbose", "iso_8601"],
+ )
+ def test_load_interval_intervalstyle(self, conn, val, expr, intervalstyle):
+ cur = conn.cursor(binary=False)
+ cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+ cur.execute(f"select '{expr}'::interval")
+ assert cur.fetchone()[0] == as_td(val)
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ @pytest.mark.parametrize("val", ["min", "max"])
+ def test_load_interval_overflow(self, conn, val, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(
+ "select %s + %s * '1s'::interval",
+ (as_td(val), -1 if val == "min" else 1),
+ )
+ with pytest.raises(DataError):
+ cur.fetchone()[0]
+
+ @pytest.mark.crdb_skip("copy")
+ def test_load_copy(self, conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy(
+ """
+ copy (
+ select
+ '1 days +00:00:01.000001'::interval,
+ 'foo bar'::text
+ ) to stdout
+ """
+ ) as copy:
+ copy.set_types(["interval", "text"])
+ rec = copy.read_row()
+
+ want = dt.timedelta(days=1, seconds=1, microseconds=1)
+ assert rec[0] == want
+ assert rec[1] == "foo bar"
+
+
+#
+# Support
+#
+
+
+def as_date(s):
+ return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
+
+
+def as_time(s):
+ if "~" in s:
+ s, off = s.split("~")
+ else:
+ off = None
+
+ if "," in s:
+ rv = dt.time(*map(int, s.split(","))) # type: ignore[arg-type]
+ else:
+ rv = getattr(dt.time, s)
+ if off:
+ rv = rv.replace(tzinfo=as_tzinfo(off))
+
+ return rv
+
+
+def as_dt(s):
+ if "~" not in s:
+ return as_naive_dt(s)
+
+ s, off = s.split("~")
+ rv = as_naive_dt(s)
+ off = as_tzoffset(off)
+ rv = (rv - off).replace(tzinfo=dt.timezone.utc)
+ return rv
+
+
+def as_naive_dt(s):
+ if "," in s:
+ rv = dt.datetime(*map(int, s.split(","))) # type: ignore[arg-type]
+ else:
+ rv = getattr(dt.datetime, s)
+
+ return rv
+
+
+def as_tzoffset(s):
+ if s.startswith("-"):
+ mul = -1
+ s = s[1:]
+ else:
+ mul = 1
+
+ fields = ("hours", "minutes", "seconds")
+ return mul * dt.timedelta(**dict(zip(fields, map(int, s.split(":")))))
+
+
+def as_tzinfo(s):
+ off = as_tzoffset(s)
+ return dt.timezone(off)
+
+
+def as_td(s):
+ if s in ("min", "max"):
+ return getattr(dt.timedelta, s)
+
+ suffixes = {"d": "days", "s": "seconds", "m": "microseconds"}
+ kwargs = {}
+ for part in s.split(","):
+ kwargs[suffixes[part[-1]]] = int(part[:-1])
+
+ return dt.timedelta(**kwargs)
diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py
new file mode 100644
index 0000000..8dfb6d4
--- /dev/null
+++ b/tests/types/test_enum.py
@@ -0,0 +1,363 @@
+from enum import Enum, auto
+
+import pytest
+
+from psycopg import pq, sql, errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
+from psycopg.types.enum import EnumInfo, register_enum
+
+from ..fix_crdb import crdb_encoding
+
+
+class PureTestEnum(Enum):
+ FOO = auto()
+ BAR = auto()
+ BAZ = auto()
+
+
+class StrTestEnum(str, Enum):
+ ONE = "ONE"
+ TWO = "TWO"
+ THREE = "THREE"
+
+
+NonAsciiEnum = Enum(
+ "NonAsciiEnum",
+ {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"},
+ type=str,
+)
+
+
+class IntTestEnum(int, Enum):
+ ONE = 1
+ TWO = 2
+ THREE = 3
+
+
+enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
+encodings = ["utf8", crdb_encoding("latin1")]
+
+
+@pytest.fixture(scope="session", autouse=True)
+def make_test_enums(request, svcconn):
+ for enum in enum_cases + [NonAsciiEnum]:
+ ensure_enum(enum, svcconn)
+
+
+def ensure_enum(enum, conn):
+ name = enum.__name__.lower()
+ labels = list(enum.__members__)
+ conn.execute(
+ sql.SQL(
+ """
+ drop type if exists {name};
+ create type {name} as enum ({labels});
+ """
+ ).format(name=sql.Identifier(name), labels=sql.SQL(",").join(labels))
+ )
+ return name, enum, labels
+
+
+def test_fetch_info(conn):
+ info = EnumInfo.fetch(conn, "StrTestEnum")
+ assert info.name == "strtestenum"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.labels) == len(StrTestEnum)
+ assert info.labels == list(StrTestEnum.__members__)
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_async(aconn):
+ info = await EnumInfo.fetch(aconn, "PureTestEnum")
+ assert info.name == "puretestenum"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.labels) == len(PureTestEnum)
+ assert info.labels == list(PureTestEnum.__members__)
+
+
+def test_register_makes_a_type(conn):
+ info = EnumInfo.fetch(conn, "IntTestEnum")
+ assert info
+ assert info.enum is None
+ register_enum(info, context=conn)
+ assert info.enum is not None
+ assert [e.name for e in info.enum] == list(IntTestEnum.__members__)
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum=enum)
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
+ enum = NonAsciiEnum
+ conn.execute(f"set client_encoding to {encoding}")
+
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum=enum)
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+ conn.execute("set client_encoding to sql_ascii")
+
+ for label in info.labels:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == enum[label]
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out):
+ enum = NonAsciiEnum
+ conn.execute(f"set client_encoding to {encoding}")
+
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out):
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+ conn.execute("set client_encoding to sql_ascii")
+
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_dumper(conn, enum, fmt_in, fmt_out):
+ for item in enum:
+ if enum is PureTestEnum:
+ want = item.name
+ else:
+ want = item.value
+
+ cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
+ for item in NonAsciiEnum:
+ cur = conn.execute(f"select %{fmt_in.value}", [item.value], binary=fmt_out)
+ assert cur.fetchone()[0] == item.value
+
+
+@pytest.mark.parametrize("enum", enum_cases)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_loader(conn, enum, fmt_in, fmt_out):
+ for label in enum.__members__:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out
+ )
+ want = enum[label].name
+ if fmt_out == pq.Format.BINARY:
+ want = want.encode()
+ assert cur.fetchone()[0] == want
+
+
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
+
+ for label in NonAsciiEnum.__members__:
+ cur = conn.execute(
+ f"select %{fmt_in.value}::nonasciienum", [label], binary=fmt_out
+ )
+ if fmt_out == pq.Format.TEXT:
+ assert cur.fetchone()[0] == label
+ else:
+ assert cur.fetchone()[0] == label.encode(encoding)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, fmt_in, fmt_out):
+ enum = PureTestEnum
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ labels = list(enum.__members__)
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out
+ )
+ assert cur.fetchone()[0] == list(enum)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_dumper(conn, fmt_in, fmt_out):
+ enum = StrTestEnum
+ info = EnumInfo.fetch(conn, enum.__name__)
+ register_enum(info, conn, enum)
+
+ cur = conn.execute(f"select %{fmt_in.value}::text[]", [list(enum)], binary=fmt_out)
+ assert cur.fetchone()[0] == list(enum.__members__)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_array_loader(conn, fmt_in, fmt_out):
+ enum = IntTestEnum
+ info = TypeInfo.fetch(conn, enum.__name__)
+ info.register(conn)
+ labels = list(enum.__members__)
+ cur = conn.execute(
+ f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out
+ )
+ if fmt_out == pq.Format.TEXT:
+ assert cur.fetchone()[0] == labels
+ else:
+ assert cur.fetchone()[0] == [item.encode() for item in labels]
+
+
+def test_enum_error(conn):
+ conn.autocommit = True
+
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, StrTestEnum)
+
+ with pytest.raises(e.DataError):
+ conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone()
+ with pytest.raises(e.DataError):
+ conn.execute("select 'BAR'::puretestenum").fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "mapping",
+ [
+ {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"},
+ [
+ (StrTestEnum.ONE, "FOO"),
+ (StrTestEnum.TWO, "BAR"),
+ (StrTestEnum.THREE, "BAZ"),
+ ],
+ ],
+)
+def test_remap(conn, fmt_in, fmt_out, mapping):
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, StrTestEnum, mapping=mapping)
+
+ for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]:
+ cur = conn.execute(f"select %{fmt_in.value}::text", [StrTestEnum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out)
+ assert cur.fetchone()[0] is StrTestEnum[member]
+
+
+def test_remap_rename(conn):
+ enum = Enum("RenamedEnum", "FOO BAR QUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"})
+
+ for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_python(conn):
+ enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]}
+ register_enum(info, conn, enum, mapping=mapping)
+
+ for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+
+ for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_postgres(conn):
+ enum = Enum("SmallerEnum", "FOO")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")]
+ register_enum(info, conn, enum, mapping=mapping)
+
+ cur = conn.execute("select %s::text", [enum.FOO])
+ assert cur.fetchone()[0] == "BAZ"
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum.FOO
+
+
+def test_remap_by_value(conn):
+ enum = Enum( # type: ignore
+ "ByValue",
+ {m.lower(): m for m in PureTestEnum.__members__},
+ )
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={m: m.value for m in enum})
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute("select %s::text", [enum[label.lower()]])
+ assert cur.fetchone()[0] == label
+
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[label.lower()]
diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py
new file mode 100644
index 0000000..5142d58
--- /dev/null
+++ b/tests/types/test_hstore.py
@@ -0,0 +1,107 @@
+import pytest
+
+import psycopg
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import HstoreLoader, register_hstore
+
+pytestmark = pytest.mark.crdb_skip("hstore")
+
+
+@pytest.mark.parametrize(
+ "s, d",
+ [
+ ("", {}),
+ ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}),
+ ('"a" => "1" , "b" => "2"', {"a": "1", "b": "2"}),
+ ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}),
+ (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}),
+ ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}),
+ ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}),
+ (r'"a\\"=>"1"', {"a\\": "1"}),
+ (r'"a\""=>"1"', {'a"': "1"}),
+ (r'"a\\\""=>"1"', {r"a\"": "1"}),
+ (r'"a\\\\\""=>"1"', {r'a\\"': "1"}),
+ ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}),
+ ],
+)
+def test_parse_ok(s, d):
+ loader = HstoreLoader(0, None)
+ assert loader.load(s.encode()) == d
+
+
+@pytest.mark.parametrize(
+ "s",
+ [
+ "a",
+ '"a"',
+ r'"a\\""=>"1"',
+ r'"a\\\\""=>"1"',
+ '"a=>"1"',
+ '"a"=>"1", "b"=>NUL',
+ ],
+)
+def test_parse_bad(s):
+ with pytest.raises(psycopg.DataError):
+ loader = HstoreLoader(0, None)
+ loader.load(s.encode())
+
+
+def test_register_conn(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ register_hstore(info, conn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_curs(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ cur = conn.cursor()
+ register_hstore(info, cur)
+ assert conn.adapters.types.get(info.oid) is None
+ assert cur.adapters.types[info.oid].name == "hstore"
+
+ cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters):
+ info = TypeInfo.fetch(svcconn, "hstore")
+ register_hstore(info)
+ assert psycopg.adapters.types[info.oid].name == "hstore"
+
+ assert svcconn.adapters.types.get(info.oid) is None
+ conn = conn_cls.connect(dsn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+ conn.close()
+
+
+ab = list(map(chr, range(32, 128)))
+samp = [
+ {},
+ {"a": "b", "c": None},
+ dict(zip(ab, ab)),
+ {"".join(ab): "".join(ab)},
+]
+
+
+@pytest.mark.parametrize("d", samp)
+def test_roundtrip(hstore, conn, d):
+ register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
+ d1 = conn.execute("select %s", [d]).fetchone()[0]
+ assert d == d1
+
+
+def test_roundtrip_array(hstore, conn):
+ register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
+ samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+ assert samp1 == samp
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="hstore.*extension"):
+ register_hstore(None, conn) # type: ignore[arg-type]
diff --git a/tests/types/test_json.py b/tests/types/test_json.py
new file mode 100644
index 0000000..50e8ce3
--- /dev/null
+++ b/tests/types/test_json.py
@@ -0,0 +1,182 @@
+import json
+from copy import deepcopy
+
+import pytest
+
+import psycopg.types
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+from psycopg.types.json import set_json_dumps, set_json_loads
+
+samples = [
+ "null",
+ "true",
+ '"te\'xt"',
+ '"\\u00e0\\u20ac"',
+ "123",
+ "123.45",
+ '["a", 100]',
+ '{"a": 100}',
+]
+
+
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_wrapper_regtype(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ cur = conn.cursor()
+ cur.execute(
+ f"select pg_typeof(%{fmt_in.value})::regtype = %s::regtype",
+ (wrapper([]), wrapper.__name__.lower()),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump(conn, val, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = json.loads(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select %{fmt_in.value}::text = %s::{wrapper.__name__.lower()}::text",
+ (wrapper(obj), val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.crdb_skip("json array")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_array_dump(conn, val, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = json.loads(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select %{fmt_in.value}::text = array[%s::{wrapper.__name__.lower()}]::text",
+ ([wrapper(obj)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load(conn, val, jtype, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{jtype}", (val,))
+ assert cur.fetchone()[0] == json.loads(val)
+
+
+@pytest.mark.crdb_skip("json array")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_array(conn, val, jtype, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select array[%s::{jtype}]", (val,))
+ assert cur.fetchone()[0] == [json.loads(val)]
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_copy(conn, val, jtype, fmt_out):
+ cur = conn.cursor()
+ stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format(
+ val, sql.Identifier(jtype), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([jtype])
+ (got,) = copy.read_row()
+
+ assert got == json.loads(val)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur = conn.cursor()
+
+ set_json_dumps(my_dumps)
+ try:
+ cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj),))
+ assert cur.fetchone()[0] is True
+ finally:
+ set_json_dumps(json.dumps)
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise_context(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+
+ set_json_dumps(my_dumps, cur2)
+ cur1.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),))
+ assert cur1.fetchone()[0] is None
+ cur2.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),))
+ assert cur2.fetchone()[0] == "qux"
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_dump_customise_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj, my_dumps),))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("binary", [True, False])
+@pytest.mark.parametrize("pgtype", ["json", "jsonb"])
+def test_load_customise(conn, binary, pgtype):
+ cur = conn.cursor(binary=binary)
+
+ set_json_loads(my_loads)
+ try:
+ cur.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ obj = cur.fetchone()[0]
+ assert obj["foo"] == "bar"
+ assert obj["answer"] == 42
+ finally:
+ set_json_loads(json.loads)
+
+
+@pytest.mark.parametrize("binary", [True, False])
+@pytest.mark.parametrize("pgtype", ["json", "jsonb"])
+def test_load_customise_context(conn, binary, pgtype):
+ cur1 = conn.cursor(binary=binary)
+ cur2 = conn.cursor(binary=binary)
+
+ set_json_loads(my_loads, cur2)
+ cur1.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ got = cur1.fetchone()[0]
+ assert got["foo"] == "bar"
+ assert "answer" not in got
+
+ cur2.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+ got = cur2.fetchone()[0]
+ assert got["foo"] == "bar"
+ assert got["answer"] == 42
+
+
+def my_dumps(obj):
+ obj = deepcopy(obj)
+ obj["baz"] = "qux"
+ return json.dumps(obj)
+
+
+def my_loads(data):
+ obj = json.loads(data)
+ obj["answer"] = 42
+ return obj
diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py
new file mode 100644
index 0000000..2ab5152
--- /dev/null
+++ b/tests/types/test_multirange.py
@@ -0,0 +1,434 @@
+import pickle
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg import pq, sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types.range import Range
+from psycopg.types import multirange
+from psycopg.types.multirange import Multirange, MultirangeInfo
+from psycopg.types.multirange import register_multirange
+
+from ..utils import eur
+from .test_range import create_test_range
+
+pytestmark = [
+ pytest.mark.pg(">= 14"),
+ pytest.mark.crdb_skip("range"),
+]
+
+
+class TestMultirangeObject:
+ def test_empty(self):
+ mr = Multirange[int]()
+ assert not mr
+ assert len(mr) == 0
+
+ mr = Multirange([])
+ assert not mr
+ assert len(mr) == 0
+
+ def test_sequence(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ assert mr
+ assert len(mr) == 3
+ assert mr[2] == Range(50, 60)
+ assert mr[-2] == Range(30, 40)
+
+ def test_bad_type(self):
+ with pytest.raises(TypeError):
+ Multirange(Range(10, 20)) # type: ignore[arg-type]
+
+ with pytest.raises(TypeError):
+ Multirange([10]) # type: ignore[arg-type]
+
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+
+ with pytest.raises(TypeError):
+ mr[0] = "foo" # type: ignore[call-overload]
+
+ with pytest.raises(TypeError):
+ mr[0:1] = "foo" # type: ignore[assignment]
+
+ with pytest.raises(TypeError):
+ mr[0:1] = ["foo"] # type: ignore[list-item]
+
+ with pytest.raises(TypeError):
+ mr.insert(0, "foo") # type: ignore[arg-type]
+
+ def test_setitem(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ mr[1] = Range(31, 41)
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
+ def test_setitem_slice(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ mr[1:3] = [Range(31, 41), Range(51, 61)]
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)])
+
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ with pytest.raises(TypeError, match="can only assign an iterable"):
+ mr[1:3] = Range(31, 41) # type: ignore[call-overload]
+
+ mr[1:3] = [Range(31, 41)]
+ assert mr == Multirange([Range(10, 20), Range(31, 41)])
+
+ def test_delitem(self):
+ mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+ del mr[1]
+ assert mr == Multirange([Range(10, 20), Range(50, 60)])
+
+ del mr[-2]
+ assert mr == Multirange([Range(50, 60)])
+
+ def test_insert(self):
+ mr = Multirange([Range(10, 20), Range(50, 60)])
+ mr.insert(1, Range(31, 41))
+ assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
+ def test_relations(self):
+ mr1 = Multirange([Range(10, 20), Range(30, 40)])
+ mr2 = Multirange([Range(11, 20), Range(30, 40)])
+ mr3 = Multirange([Range(9, 20), Range(30, 40)])
+ assert mr1 <= mr1
+ assert not mr1 < mr1
+ assert mr1 >= mr1
+ assert not mr1 > mr1
+ assert mr1 < mr2
+ assert mr1 <= mr2
+ assert mr1 > mr3
+ assert mr1 >= mr3
+ assert mr1 != mr2
+ assert not mr1 == mr2
+
+ def test_pickling(self):
+ r = Multirange([Range(0, 4)])
+ assert pickle.loads(pickle.dumps(r)) == r
+
+ def test_str(self):
+ mr = Multirange([Range(10, 20), Range(30, 40)])
+ assert str(mr) == "{[10, 20), [30, 40)}"
+
+ def test_repr(self):
+ mr = Multirange([Range(10, 20), Range(30, 40)])
+ expected = "Multirange([Range(10, 20, '[)'), Range(30, 40, '[)')])"
+ assert repr(mr) == expected
+
+
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
+samples = [
+ ("int4multirange", [Range(None, None, "()")]),
+ ("int4multirange", [Range(10, 20), Range(30, 40)]),
+ ("int8multirange", [Range(None, None, "()")]),
+ ("int8multirange", [Range(10, 20), Range(30, 40)]),
+ (
+ "nummultirange",
+ [
+ Range(None, Decimal(-100)),
+ Range(Decimal(100), Decimal("100.123")),
+ ],
+ ),
+ (
+ "datemultirange",
+ [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))],
+ ),
+ (
+ "tsmultirange",
+ [
+ Range(
+ dt.datetime(2000, 1, 1, 00, 00),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+ )
+ ],
+ ),
+ (
+ "tstzmultirange",
+ [
+ Range(
+ dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ ),
+ Range(
+ dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ ),
+ ],
+ ),
+]
+
+mr_names = """
+ int4multirange int8multirange nummultirange
+ datemultirange tsmultirange tstzmultirange""".split()
+
+mr_classes = """
+ Int4Multirange Int8Multirange NumericMultirange
+ DateMultirange TimestampMultirange TimestamptzMultirange""".split()
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
+ mr = Multirange() # type: ignore[var-annotated]
+ cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in.value}", (mr,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+ dumper = getattr(multirange, wrapper + "Dumper")
+ wrapper = getattr(multirange, wrapper)
+ mr = wrapper()
+ rec = conn.execute(
+ f"""
+ select '{{}}' = %(mr){fmt_in.value},
+ %(mr){fmt_in.value}::text,
+ pg_typeof(%(mr){fmt_in.value})::oid
+ """,
+ {"mr": mr},
+ ).fetchone()
+ assert rec[0] is True, rec[1]
+ assert rec[2] == dumper.oid
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ PyFormat.AUTO,
+ PyFormat.TEXT,
+ # There are many ways to work around this (use text, use a cast on the
+ # placeholder, use specific Range subclasses).
+ pytest.param(
+ PyFormat.BINARY,
+ marks=pytest.mark.xfail(
+ reason="can't dump array of untypes binary multirange without cast"
+ ),
+ ),
+ ],
+)
+def test_dump_builtin_array(conn, pgtype, fmt_in):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in.value}",
+ ([mr1, mr2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"""
+ select array['{{}}'::{pgtype},
+ '{{(,)}}'::{pgtype}] = %{fmt_in.value}::{pgtype}[]
+ """,
+ ([mr1, mr2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(multirange, wrapper)
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in.value}""", ([mr1, mr2],)
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in):
+ mr = Multirange(ranges)
+ rname = pgtype.replace("multi", "")
+ phs = ", ".join([f"%s::{rname}"] * len(ranges))
+ cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in.value}", ranges + [mr])
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_empty(conn, pgtype, fmt_out):
+ mr = Multirange() # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone()
+ assert type(got) is Multirange
+ assert got == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_array(conn, pgtype, fmt_out):
+ mr1 = Multirange() # type: ignore[var-annotated]
+ mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(
+ f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]"
+ ).fetchone()
+ assert got == [mr1, mr2]
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_range(conn, pgtype, ranges, fmt_out):
+ mr = Multirange(ranges)
+ rname = pgtype.replace("multi", "")
+ phs = ", ".join([f"%s::{rname}"] * len(ranges))
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select {pgtype}({phs})", ranges)
+ assert cur.fetchone()[0] == mr
+
+
+@pytest.mark.parametrize(
+ "min, max, bounds",
+ [
+ ("2000,1,1", "2001,1,1", "[)"),
+ ("2000,1,1", None, "[)"),
+ (None, "2001,1,1", "()"),
+ (None, None, "()"),
+ (None, None, "empty"),
+ ],
+)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in(conn, min, max, bounds, format):
+ cur = conn.cursor()
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
+
+ if bounds != "empty":
+ min = dt.date(*map(int, min.split(","))) if min else None
+ max = dt.date(*map(int, max.split(","))) if max else None
+ r = Range[dt.date](min, max, bounds)
+ else:
+ r = Range(empty=True)
+
+ mr = Multirange([r])
+ try:
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.write_row([mr])
+ except e.InternalError_:
+ if not min and not max and format == pq.Format.BINARY:
+ pytest.xfail("TODO: add annotation to dump multirange with no type info")
+ else:
+ raise
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ if not r.isempty:
+ assert rec[0] == mr
+ else:
+ assert rec[0] == Multirange()
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_wrappers(conn, wrapper, format):
+ cur = conn.cursor()
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
+
+ mr = getattr(multirange, wrapper)()
+
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.write_row([mr])
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ assert rec[0] == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_set_type(conn, pgtype, format):
+ cur = conn.cursor()
+ cur.execute(f"create table copymr (id serial primary key, mr {pgtype})")
+
+ mr = Multirange() # type: ignore[var-annotated]
+
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
+ copy.set_types([pgtype])
+ copy.write_row([mr])
+
+ rec = cur.execute("select mr from copymr order by id").fetchone()
+ assert rec[0] == mr
+
+
+@pytest.fixture(scope="session")
+def testmr(svcconn):
+ create_test_range(svcconn)
+
+
+fetch_cases = [
+ ("testmultirange", "text"),
+ ("testschema.testmultirange", "float8"),
+ (sql.Identifier("testmultirange"), "text"),
+ (sql.Identifier("testschema", "testmultirange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testmr, name, subtype):
+ info = MultirangeInfo.fetch(conn, name)
+ assert info.name == "testmultirange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == conn.adapters.types[subtype].oid
+
+
+def test_fetch_info_not_found(conn):
+ assert MultirangeInfo.fetch(conn, "nosuchrange") is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testmr, name, subtype): # noqa: F811
+ info = await MultirangeInfo.fetch(aconn, name)
+ assert info.name == "testmultirange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == aconn.adapters.types[subtype].oid
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+ assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None
+
+
+def test_dump_custom_empty(conn, testmr):
+ info = MultirangeInfo.fetch(conn, "testmultirange")
+ register_multirange(info, conn)
+
+ r = Multirange() # type: ignore[var-annotated]
+ cur = conn.execute("select '{}'::testmultirange = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_custom_empty(conn, testmr, fmt_out):
+ info = MultirangeInfo.fetch(conn, "testmultirange")
+ register_multirange(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute("select '{}'::testmultirange").fetchone()
+ assert isinstance(got, Multirange)
+ assert not got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(f'create type "{name}" as range (subtype = text)')
+ info = MultirangeInfo.fetch(conn, f'"{name}_multirange"')
+ register_multirange(info, conn)
+ obj = Multirange([Range("a", "z", "[]")])
+ assert sql.Literal(obj).as_string(conn) == f"'{{[a,z]}}'::\"{name}_multirange\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ assert cur.fetchone()[0] == obj
diff --git a/tests/types/test_net.py b/tests/types/test_net.py
new file mode 100644
index 0000000..8739398
--- /dev/null
+++ b/tests/types/test_net.py
@@ -0,0 +1,135 @@
+import ipaddress
+
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+
+crdb_skip_cidr = pytest.mark.crdb_skip("cidr")
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"])
+def test_address_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::inet", (ipaddress.ip_address(val), val))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select %{fmt_in.value} = array[null, %s]::inet[]",
+ ([None, ipaddress.ip_interface(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/128"])
+def test_interface_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ rec = cur.execute(
+ f"select %(val){fmt_in.value} = %(repr)s::inet,"
+ f" %(val){fmt_in.value}, %(repr)s::inet",
+ {"val": ipaddress.ip_interface(val), "repr": val},
+ ).fetchone()
+ assert rec[0] is True, f"{rec[1]} != {rec[2]}"
+ cur.execute(
+ f"select %{fmt_in.value} = array[null, %s]::inet[]",
+ ([None, ipaddress.ip_interface(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
+def test_network_dump(conn, fmt_in, val):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::cidr", (ipaddress.ip_network(val), val))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select %{fmt_in.value} = array[NULL, %s]::cidr[]",
+ ([None, ipaddress.ip_network(val)], val),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_network_mixed_size_array(conn, fmt_in):
+ val = [
+ ipaddress.IPv4Network("192.168.0.1/32"),
+ ipaddress.IPv6Network("::1/128"),
+ ]
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}", (val,))
+ got = cur.fetchone()[0]
+ assert val == got
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
+def test_inet_load_address(conn, fmt_out, val):
+ addr = ipaddress.ip_address(val.split("/", 1)[0])
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::inet", (val,))
+ assert cur.fetchone()[0] == addr
+
+ cur.execute("select array[null, %s::inet]", (val,))
+ assert cur.fetchone()[0] == [None, addr]
+
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == addr
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"])
+def test_inet_load_network(conn, fmt_out, val):
+ pyval = ipaddress.ip_interface(val)
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::inet", (val,))
+ assert cur.fetchone()[0] == pyval
+
+ cur.execute("select array[null, %s::inet]", (val,))
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
+
+
+@crdb_skip_cidr
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
+def test_cidr_load(conn, fmt_out, val):
+ pyval = ipaddress.ip_network(val)
+ cur = conn.cursor(binary=fmt_out)
+
+ cur.execute("select %s::cidr", (val,))
+ assert cur.fetchone()[0] == pyval
+
+ cur.execute("select array[null, %s::cidr]", (val,))
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["cidr"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
diff --git a/tests/types/test_none.py b/tests/types/test_none.py
new file mode 100644
index 0000000..4c008fd
--- /dev/null
+++ b/tests/types/test_none.py
@@ -0,0 +1,12 @@
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+
+
+def test_quote_none(conn):
+
+ tx = Transformer()
+ assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL"
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+ assert cur.fetchone()[0] is None
diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py
new file mode 100644
index 0000000..a27bc84
--- /dev/null
+++ b/tests/types/test_numeric.py
@@ -0,0 +1,625 @@
+import enum
+from decimal import Decimal
+from math import isnan, isinf, exp
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import Transformer, PyFormat
+from psycopg.types.numeric import FloatLoader
+
+from ..fix_crdb import is_crdb
+
+#
+# Tests with int
+#
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::int"),
+ (1, "'1'::int"),
+ (-1, "'-1'::int"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
+ (int(2**63 - 1), "'9223372036854775807'::bigint"),
+ (int(-(2**63)), "'-9223372036854775808'::bigint"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_int(conn, val, expr, fmt_in):
+ assert isinstance(val, int)
+ cur = conn.cursor()
+ cur.execute(f"select {expr} = %{fmt_in.value}", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, "'0'::smallint"),
+ (1, "'1'::smallint"),
+ (-1, "'-1'::smallint"),
+ (42, "'42'::smallint"),
+ (-42, "'-42'::smallint"),
+ (int(2**15 - 1), f"'{2 ** 15 - 1}'::smallint"),
+ (int(-(2**15)), f"'{-2 ** 15}'::smallint"),
+ (int(2**15), f"'{2 ** 15}'::integer"),
+ (int(-(2**15) - 1), f"'{-2 ** 15 - 1}'::integer"),
+ (int(2**31 - 1), f"'{2 ** 31 - 1}'::integer"),
+ (int(-(2**31)), f"'{-2 ** 31}'::integer"),
+ (int(2**31), f"'{2 ** 31}'::bigint"),
+ (int(-(2**31) - 1), f"'{-2 ** 31 - 1}'::bigint"),
+ (int(2**63 - 1), f"'{2 ** 63 - 1}'::bigint"),
+ (int(-(2**63)), f"'{-2 ** 63}'::bigint"),
+ (int(2**63), f"'{2 ** 63}'::numeric"),
+ (int(-(2**63) - 1), f"'{-2 ** 63 - 1}'::numeric"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_int_subtypes(conn, val, expr, fmt_in):
+ cur = conn.cursor()
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,))
+ assert cur.fetchone()[0] is True
+ cur.execute(
+ f"select {expr} = %(v){fmt_in.value}, {expr}::text, %(v){fmt_in.value}::text",
+ {"v": val},
+ )
+ ok, want, got = cur.fetchone()
+ assert got == want
+ assert ok
+
+
+class MyEnum(enum.IntEnum):
+ foo = 42
+
+
+class MyMixinEnum(enum.IntEnum):
+ foo = 42000000
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("enum", [MyEnum, MyMixinEnum])
+def test_dump_enum(conn, fmt_in, enum):
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value}", (enum.foo,))
+ (res,) = cur.fetchone()
+ assert res == enum.foo.value
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, b"0"),
+ (1, b"1"),
+ (-1, b" -1"),
+ (42, b"42"),
+ (-42, b" -42"),
+ (int(2**63 - 1), b"9223372036854775807"),
+ (int(-(2**63)), b" -9223372036854775808"),
+ (int(2**63), b"9223372036854775808"),
+ (int(-(2**63 + 1)), b" -9223372036854775809"),
+ (int(2**100), b"1267650600228229401496703205376"),
+ (int(-(2**100)), b" -1267650600228229401496703205376"),
+ ],
+)
+def test_quote_int(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ assert cur.fetchone() == (val, -val)
+
+
+@pytest.mark.parametrize(
+ "val, pgtype, want",
+ [
+ ("0", "integer", 0),
+ ("1", "integer", 1),
+ ("-1", "integer", -1),
+ ("0", "int2", 0),
+ ("0", "int4", 0),
+ ("0", "int8", 0),
+ ("0", "integer", 0),
+ ("0", "oid", 0),
+ # bounds
+ ("-32768", "smallint", -32768),
+ ("+32767", "smallint", 32767),
+ ("-2147483648", "integer", -2147483648),
+ ("+2147483647", "integer", 2147483647),
+ ("-9223372036854775808", "bigint", -9223372036854775808),
+ ("9223372036854775807", "bigint", 9223372036854775807),
+ ("4294967295", "oid", 4294967295),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_int(conn, val, pgtype, want, fmt_out):
+ if pgtype == "integer" and is_crdb(conn):
+ pgtype = "int4" # "integer" is "int8" on crdb
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{pgtype}", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid
+ result = cur.fetchone()[0]
+ assert result == want
+ assert type(result) is type(want)
+
+ # arrays work too
+ cur.execute(f"select array[%s::{pgtype}]", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid
+ result = cur.fetchone()[0]
+ assert result == [want]
+ assert type(result[0]) is type(want)
+
+
+#
+# Tests with float
+#
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0.0, "'0'"),
+ (1.0, "'1'"),
+ (-1.0, "'-1'"),
+ (float("nan"), "'NaN'"),
+ (float("inf"), "'Infinity'"),
+ (float("-inf"), "'-Infinity'"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_float(conn, val, expr, fmt_in):
+ assert isinstance(val, float)
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = {expr}::float8", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0.0, b"0.0"),
+ (1.0, b"1.0"),
+ (10000000000000000.0, b"1e+16"),
+ (1000000.1, b"1000000.1"),
+ (-100000.000001, b" -100000.000001"),
+ (-1.0, b" -1.0"),
+ (float("nan"), b"'NaN'::float8"),
+ (float("inf"), b"'Infinity'::float8"),
+ (float("-inf"), b"'-Infinity'::float8"),
+ ],
+)
+def test_quote_float(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+ if isnan(val):
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ if isinstance(r[0], Decimal):
+ r = tuple(map(float, r))
+
+ assert r == (val, -val)
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (exp(1), "exp(1.0)"),
+ (-exp(1), "-exp(1.0)"),
+ (1e30, "'1e30'"),
+ (1e-30, "1e-30"),
+ (-1e30, "'-1e30'"),
+ (-1e-30, "-1e-30"),
+ ],
+)
+def test_dump_float_approx(conn, val, expr):
+ assert isinstance(val, float)
+ cur = conn.cursor()
+ cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, pgtype, want",
+ [
+ ("0", "float4", 0.0),
+ ("0.0", "float4", 0.0),
+ ("42", "float4", 42.0),
+ ("-42", "float4", -42.0),
+ ("0.0", "float8", 0.0),
+ ("0.0", "real", 0.0),
+ ("0.0", "double precision", 0.0),
+ ("0.0", "float4", 0.0),
+ ("nan", "float4", float("nan")),
+ ("inf", "float4", float("inf")),
+ ("-inf", "float4", -float("inf")),
+ ("nan", "float8", float("nan")),
+ ("inf", "float8", float("inf")),
+ ("-inf", "float8", -float("inf")),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_float(conn, val, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select %s::{pgtype}", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid
+ result = cur.fetchone()[0]
+
+ def check(result, want):
+ assert type(result) is type(want)
+ if isnan(want):
+ assert isnan(result)
+ elif isinf(want):
+ assert isinf(result)
+ assert (result < 0) is (want < 0)
+ else:
+ assert result == want
+
+ check(result, want)
+
+ cur.execute(f"select array[%s::{pgtype}]", (val,))
+ assert cur.pgresult.fformat(0) == fmt_out
+ assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid
+ result = cur.fetchone()[0]
+ assert isinstance(result, list)
+ check(result[0], want)
+
+
+@pytest.mark.parametrize(
+ "expr, pgtype, want",
+ [
+ ("exp(1.0)", "float4", 2.71828),
+ ("-exp(1.0)", "float4", -2.71828),
+ ("exp(1.0)", "float8", 2.71828182845905),
+ ("-exp(1.0)", "float8", -2.71828182845905),
+ ("1.42e10", "float4", 1.42e10),
+ ("-1.42e10", "float4", -1.42e10),
+ ("1.42e40", "float8", 1.42e40),
+ ("-1.42e40", "float8", -1.42e40),
+ ],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_float_approx(conn, expr, pgtype, want, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute("select %s::%s" % (expr, pgtype))
+ assert cur.pgresult.fformat(0) == fmt_out
+ result = cur.fetchone()[0]
+ assert result == pytest.approx(want)
+
+
+@pytest.mark.crdb_skip("copy")
+def test_load_float_copy(conn):
+ cur = conn.cursor(binary=False)
+ with cur.copy("copy (select 3.14::float8, 'hi'::text) to stdout;") as copy:
+ copy.set_types(["float8", "text"])
+ rec = copy.read_row()
+
+ assert rec[0] == pytest.approx(3.14)
+ assert rec[1] == "hi"
+
+
+#
+# Tests with decimal
+#
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "0",
+ "-0",
+ "0.0",
+ "0.000000000000000000001",
+ "-0.000000000000000000001",
+ "nan",
+ "snan",
+ ],
+)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_roundtrip_numeric(conn, val, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ val = Decimal(val)
+ cur.execute(f"select %{fmt_in.value}", (val,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, Decimal)
+ if val.is_nan():
+ assert result.is_nan()
+ else:
+ assert result == val
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("0", b"0"),
+ ("0.0", b"0.0"),
+ ("0.00000000000000001", b"1E-17"),
+ ("-0.00000000000000001", b" -1E-17"),
+ ("nan", b"'NaN'::numeric"),
+ ("snan", b"'NaN'::numeric"),
+ ],
+)
+def test_quote_numeric(conn, val, expr):
+ val = Decimal(val)
+ tx = Transformer()
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+
+ if val.is_nan():
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ assert r == (val, -val)
+
+
+@pytest.mark.crdb_skip("binary decimal")
+@pytest.mark.parametrize(
+ "expr",
+ ["NaN", "1", "1.0", "-1", "0.0", "0.01", "11", "1.1", "1.01", "0", "0.00"]
+ + [
+ "0.0000000",
+ "0.00001",
+ "1.00001",
+ "-1.00000000000000",
+ "-2.00000000000000",
+ "1000000000.12345",
+ "100.123456790000000000000000",
+ "1.0e-1000",
+ "1e1000",
+ "0.000000000000000000000000001",
+ "1.0000000000000000000000001",
+ "1000000000000000000000000.001",
+ "1000000000000000000000000000.001",
+ "9999999999999999999999999999.9",
+ ],
+)
+def test_dump_numeric_binary(conn, expr):
+ cur = conn.cursor()
+ val = Decimal(expr)
+ cur.execute("select %b::text, %s::decimal::text", [val, expr])
+ want, got = cur.fetchone()
+ assert got == want
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ f
+ if f != PyFormat.BINARY
+ else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+ for f in PyFormat
+ ],
+)
+def test_dump_numeric_exhaustive(conn, fmt_in):
+ cur = conn.cursor()
+
+ funcs = [
+ (lambda i: "1" + "0" * i),
+ (lambda i: "1" + "0" * i + "." + "0" * i),
+ (lambda i: "-1" + "0" * i),
+ (lambda i: "0." + "0" * i + "1"),
+ (lambda i: "-0." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "10"),
+ (lambda i: "1" + "0" * i + ".001"),
+ (lambda i: "9" + "9" * i),
+ (lambda i: "9" + "." + "9" * i),
+ (lambda i: "9" + "9" * i + ".9"),
+ (lambda i: "9" + "9" * i + "." + "9" * i),
+ (lambda i: "1.1e%s" % i),
+ (lambda i: "1.1e-%s" % i),
+ ]
+
+ for i in range(100):
+ for f in funcs:
+ expr = f(i)
+ val = Decimal(expr)
+ cur.execute(f"select %{fmt_in.value}::text, %s::decimal::text", [val, expr])
+ got, want = cur.fetchone()
+ assert got == want
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("inf", "Infinity"),
+ ("-inf", "-Infinity"),
+ ],
+)
+def test_dump_numeric_binary_inf(conn, val, expr):
+ cur = conn.cursor()
+ val = Decimal(val)
+ cur.execute("select %b", [val])
+
+
+@pytest.mark.parametrize(
+ "expr",
+ ["nan", "0", "1", "-1", "0.0", "0.01"]
+ + [
+ "0.0000000",
+ "-1.00000000000000",
+ "-2.00000000000000",
+ "1000000000.12345",
+ "100.123456790000000000000000",
+ "1.0e-1000",
+ "1e1000",
+ "0.000000000000000000000000001",
+ "1.0000000000000000000000001",
+ "1000000000000000000000000.001",
+ "1000000000000000000000000000.001",
+ "9999999999999999999999999999.9",
+ ],
+)
+def test_load_numeric_binary(conn, expr):
+ cur = conn.cursor(binary=1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(expr)
+ if val.is_nan():
+ assert res.is_nan()
+ else:
+ assert res == val
+ if "e" not in expr:
+ assert str(res) == str(val)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_numeric_exhaustive(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ funcs = [
+ (lambda i: "1" + "0" * i),
+ (lambda i: "1" + "0" * i + "." + "0" * i),
+ (lambda i: "-1" + "0" * i),
+ (lambda i: "0." + "0" * i + "1"),
+ (lambda i: "-0." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "1"),
+ (lambda i: "1." + "0" * i + "10"),
+ (lambda i: "1" + "0" * i + ".001"),
+ (lambda i: "9" + "9" * i),
+ (lambda i: "9" + "." + "9" * i),
+ (lambda i: "9" + "9" * i + ".9"),
+ (lambda i: "9" + "9" * i + "." + "9" * i),
+ ]
+
+ for i in range(100):
+ for f in funcs:
+ snum = f(i)
+ want = Decimal(snum)
+ got = cur.execute(f"select '{snum}'::decimal").fetchone()[0]
+ assert want == got
+ assert str(want) == str(got)
+
+
+@pytest.mark.pg(">= 14")
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("inf", "Infinity"),
+ ("-inf", "-Infinity"),
+ ],
+)
+def test_load_numeric_binary_inf(conn, val, expr):
+ cur = conn.cursor(binary=1)
+ res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
+ val = Decimal(val)
+ assert res == val
+
+
+@pytest.mark.parametrize(
+ "val",
+ [
+ "0",
+ "0.0",
+ "0.000000000000000000001",
+ "-0.000000000000000000001",
+ "nan",
+ ],
+)
+def test_numeric_as_float(conn, val):
+ cur = conn.cursor()
+ cur.adapters.register_loader("numeric", FloatLoader)
+
+ val = Decimal(val)
+ cur.execute("select %s as val", (val,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, float)
+ if val.is_nan():
+ assert isnan(result)
+ else:
+ assert result == pytest.approx(float(val))
+
+ # the customization works with arrays too
+ cur.execute("select %s as arr", ([val],))
+ result = cur.fetchone()[0]
+ assert isinstance(result, list)
+ assert isinstance(result[0], float)
+ if val.is_nan():
+ assert isnan(result[0])
+ else:
+ assert result[0] == pytest.approx(float(val))
+
+
+#
+# Mixed tests
+#
+
+
+@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
+def test_minus_minus(conn, pgtype):
+ cur = conn.cursor()
+ cast = f"::{pgtype}" if pgtype is not None else ""
+ cur.execute(f"select -%s{cast}", [-1])
+ result = cur.fetchone()[0]
+ assert result == 1
+
+
+@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
+def test_minus_minus_quote(conn, pgtype):
+ cur = conn.cursor()
+ cast = f"::{pgtype}" if pgtype is not None else ""
+ cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast)))
+ result = cur.fetchone()[0]
+ assert result == 1
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ obj = wrapper(1)
+ cur = conn.execute(
+ f"select %(obj){fmt_in.value} = 1, %(obj){fmt_in.value}", {"obj": obj}
+ )
+ rec = cur.fetchone()
+ assert rec[0], rec[1]
+
+
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+def test_dump_wrapper_oid(wrapper):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ base = wrapper.__mro__[1]
+ assert base in (int, float)
+ n = base(3.14)
+ assert str(wrapper(n)) == str(n)
+ assert repr(wrapper(n)) == f"{wrapper.__name__}({n})"
+
+
+@pytest.mark.crdb("skip", reason="all types returned as bigint? TODOCRDB")
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split())
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_repr_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.numeric, wrapper)
+ cur = conn.execute(f"select pg_typeof(%{fmt_in.value})::oid", [wrapper(0)])
+ oid = cur.fetchone()[0]
+ assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "typename",
+ "integer int2 int4 int8 float4 float8 numeric".split() + ["double precision"],
+)
+def test_oid_lookup(conn, typename, fmt_out):
+ dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out)
+ assert dumper.oid == conn.adapters.types[typename].oid
+ assert dumper.format == fmt_out
diff --git a/tests/types/test_range.py b/tests/types/test_range.py
new file mode 100644
index 0000000..1efd398
--- /dev/null
+++ b/tests/types/test_range.py
@@ -0,0 +1,677 @@
+import pickle
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg import pq, sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types import range as range_module
+from psycopg.types.range import Range, RangeInfo, register_range
+
+from ..utils import eur
+from ..fix_crdb import is_crdb, crdb_skip_message
+
+pytestmark = pytest.mark.crdb_skip("range")
+
+type2sub = {
+ "int4range": "int4",
+ "int8range": "int8",
+ "numrange": "numeric",
+ "daterange": "date",
+ "tsrange": "timestamp",
+ "tstzrange": "timestamptz",
+}
+
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
+samples = [
+ ("int4range", None, None, "()"),
+ ("int4range", 10, 20, "[]"),
+ ("int4range", -(2**31), (2**31) - 1, "[)"),
+ ("int8range", None, None, "()"),
+ ("int8range", 10, 20, "[)"),
+ ("int8range", -(2**63), (2**63) - 1, "[)"),
+ ("numrange", Decimal(-100), Decimal("100.123"), "(]"),
+ ("numrange", Decimal(100), None, "()"),
+ ("numrange", None, Decimal(100), "()"),
+ ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"),
+ (
+ "tsrange",
+ dt.datetime(2000, 1, 1, 00, 00),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+ "[]",
+ ),
+ (
+ "tstzrange",
+ dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+ dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+ "()",
+ ),
+]
+
+range_names = """
+ int4range int8range numrange daterange tsrange tstzrange
+ """.split()
+
+range_classes = """
+ Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+ """.split()
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty(conn, pgtype, fmt_in):
+ r = Range(empty=True) # type: ignore[var-annotated]
+ cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in.value}", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r = wrapper(empty=True)
+ cur = conn.execute(f"select 'empty' = %{fmt_in.value}", (r,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize(
+ "fmt_in",
+ [
+ PyFormat.AUTO,
+ PyFormat.TEXT,
+ # There are many ways to work around this (use text, use a cast on the
+ # placeholder, use specific Range subclasses).
+ pytest.param(
+ PyFormat.BINARY,
+ marks=pytest.mark.xfail(
+ reason="can't dump an array of untypes binary range without cast"
+ ),
+ ),
+ ],
+)
+def test_dump_builtin_array(conn, pgtype, fmt_in):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in.value}",
+ ([r1, r2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.execute(
+ f"select array['empty'::{pgtype}, '(,)'::{pgtype}] "
+ f"= %{fmt_in.value}::{pgtype}[]",
+ ([r1, r2],),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(range_module, wrapper)
+ r1 = wrapper(empty=True)
+ r2 = wrapper(bounds="()")
+ cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in.value}""", ([r1, r2],))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in):
+ r = Range(min, max, bounds) # type: ignore[var-annotated]
+ sub = type2sub[pgtype]
+ cur = conn.execute(
+ f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in.value}",
+ (min, max, bounds, r),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_empty(conn, pgtype, fmt_out):
+ r = Range(empty=True) # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+ assert type(got) is Range
+ assert got == r
+ assert not got
+ assert got.isempty
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_inf(conn, pgtype, fmt_out):
+ r = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone()
+ assert type(got) is Range
+ assert got == r
+ assert got
+ assert not got.isempty
+ assert got.lower_inf
+ assert got.upper_inf
+
+
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_array(conn, pgtype, fmt_out):
+ r1 = Range(empty=True) # type: ignore[var-annotated]
+ r2 = Range(bounds="()") # type: ignore[var-annotated]
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone()
+ assert got == [r1, r2]
+
+
+@pytest.mark.parametrize("pgtype, min, max, bounds", samples)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
+ r = Range(min, max, bounds) # type: ignore[var-annotated]
+ sub = type2sub[pgtype]
+ cur = conn.cursor(binary=fmt_out)
+ cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds))
+ # normalise discrete ranges
+ if r.upper_inc and isinstance(r.upper, int):
+ bounds = "[)" if r.lower_inc else "()"
+ r = type(r)(r.lower, r.upper + 1, bounds)
+ assert cur.fetchone()[0] == r
+
+
+@pytest.mark.parametrize(
+ "min, max, bounds",
+ [
+ ("2000,1,1", "2001,1,1", "[)"),
+ ("2000,1,1", None, "[)"),
+ (None, "2001,1,1", "()"),
+ (None, None, "()"),
+ (None, None, "empty"),
+ ],
+)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in(conn, min, max, bounds, format):
+ cur = conn.cursor()
+ cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+ if bounds != "empty":
+ min = dt.date(*map(int, min.split(","))) if min else None
+ max = dt.date(*map(int, max.split(","))) if max else None
+ r = Range[dt.date](min, max, bounds)
+ else:
+ r = Range(empty=True)
+
+ try:
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.write_row([r])
+ except e.ProtocolViolation:
+ if not min and not max and format == pq.Format.BINARY:
+ pytest.xfail("TODO: add annotation to dump ranges with no type info")
+ else:
+ raise
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize("wrapper", range_classes)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
+ cur = conn.cursor()
+ cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+ cls = getattr(range_module, wrapper)
+ r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
+
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.write_row([r])
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize("pgtype", range_names)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_set_type(conn, bounds, pgtype, format):
+ cur = conn.cursor()
+ cur.execute(f"create table copyrange (id serial primary key, r {pgtype})")
+
+ if bounds == "empty":
+ r = Range(empty=True) # type: ignore[var-annotated]
+ else:
+ r = Range(None, None, bounds)
+
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
+ copy.set_types([pgtype])
+ copy.write_row([r])
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
+@pytest.fixture(scope="session")
+def testrange(svcconn):
+ create_test_range(svcconn)
+
+
+def create_test_range(conn):
+ if is_crdb(conn):
+ pytest.skip(crdb_skip_message("range"))
+
+ conn.execute(
+ """
+ create schema if not exists testschema;
+
+ drop type if exists testrange cascade;
+ drop type if exists testschema.testrange cascade;
+
+ create type testrange as range (subtype = text, collation = "C");
+ create type testschema.testrange as range (subtype = float8);
+ """
+ )
+
+
+fetch_cases = [
+ ("testrange", "text"),
+ ("testschema.testrange", "float8"),
+ (sql.Identifier("testrange"), "text"),
+ (sql.Identifier("testschema", "testrange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testrange, name, subtype):
+ info = RangeInfo.fetch(conn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == conn.adapters.types[subtype].oid
+
+
+def test_fetch_info_not_found(conn):
+ assert RangeInfo.fetch(conn, "nosuchrange") is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testrange, name, subtype):
+ info = await RangeInfo.fetch(aconn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == aconn.adapters.types[subtype].oid
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+ assert await RangeInfo.fetch(aconn, "nosuchrange") is None
+
+
+def test_dump_custom_empty(conn, testrange):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+
+ r = Range[str](empty=True)
+ cur = conn.execute("select 'empty'::testrange = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+def test_dump_quoting(conn, testrange):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+ cur = conn.cursor()
+ for i in range(1, 254):
+ cur.execute(
+ """
+ select ascii(lower(%(r)s)) = %(low)s
+ and ascii(upper(%(r)s)) = %(up)s
+ """,
+ {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1},
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_custom_empty(conn, testrange, fmt_out):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+
+ cur = conn.cursor(binary=fmt_out)
+ (got,) = cur.execute("select 'empty'::testrange").fetchone()
+ assert isinstance(got, Range)
+ assert got.isempty
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_quoting(conn, testrange, fmt_out):
+ info = RangeInfo.fetch(conn, "testrange")
+ register_range(info, conn)
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 254):
+ cur.execute(
+ "select testrange(chr(%(low)s::int), chr(%(up)s::int))",
+ {"low": i, "up": i + 1},
+ )
+ got: Range[str] = cur.fetchone()[0]
+ assert isinstance(got, Range)
+ assert got.lower and ord(got.lower) == i
+ assert got.upper and ord(got.upper) == i + 1
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_mixed_array_types(conn, fmt_out):
+ conn.execute("create table testmix (a daterange[], b tstzrange[])")
+ r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)")
+ r2 = Range(
+ dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc),
+ dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc),
+ "[)",
+ )
+ conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]])
+ got = conn.execute("select * from testmix").fetchone()
+ assert got == ([r1], [r2])
+
+
+class TestRangeObject:
+ def test_noparam(self):
+ r = Range() # type: ignore[var-annotated]
+
+ assert not r.isempty
+ assert r.lower is None
+ assert r.upper is None
+ assert r.lower_inf
+ assert r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_empty(self):
+ r = Range(empty=True) # type: ignore[var-annotated]
+
+ assert r.isempty
+ assert r.lower is None
+ assert r.upper is None
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_nobounds(self):
+ r = Range(10, 20)
+ assert r.lower == 10
+ assert r.upper == 20
+ assert not r.isempty
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert r.lower_inc
+ assert not r.upper_inc
+
+ def test_bounds(self):
+ for bounds, lower_inc, upper_inc in [
+ ("[)", True, False),
+ ("(]", False, True),
+ ("()", False, False),
+ ("[]", True, True),
+ ]:
+ r = Range(10, 20, bounds)
+ assert r.bounds == bounds
+ assert r.lower == 10
+ assert r.upper == 20
+ assert not r.isempty
+ assert not r.lower_inf
+ assert not r.upper_inf
+ assert r.lower_inc == lower_inc
+ assert r.upper_inc == upper_inc
+
+ def test_keywords(self):
+ r = Range(upper=20)
+ r.lower is None
+ r.upper == 20
+ assert not r.isempty
+ assert r.lower_inf
+ assert not r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ r = Range(lower=10, bounds="(]")
+ r.lower == 10
+ r.upper is None
+ assert not r.isempty
+ assert not r.lower_inf
+ assert r.upper_inf
+ assert not r.lower_inc
+ assert not r.upper_inc
+
+ def test_bad_bounds(self):
+ with pytest.raises(ValueError):
+ Range(bounds="(")
+ with pytest.raises(ValueError):
+ Range(bounds="[}")
+
+ def test_in(self):
+ r = Range[int](empty=True)
+ assert 10 not in r
+ assert "x" not in r # type: ignore[operator]
+
+ r = Range()
+ assert 10 in r
+
+ r = Range(lower=10, bounds="[)")
+ assert 9 not in r
+ assert 10 in r
+ assert 11 in r
+
+ r = Range(lower=10, bounds="()")
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 in r
+
+ r = Range(upper=20, bounds="()")
+ assert 19 in r
+ assert 20 not in r
+ assert 21 not in r
+
+ r = Range(upper=20, bounds="(]")
+ assert 19 in r
+ assert 20 in r
+ assert 21 not in r
+
+ r = Range(10, 20)
+ assert 9 not in r
+ assert 10 in r
+ assert 11 in r
+ assert 19 in r
+ assert 20 not in r
+ assert 21 not in r
+
+ r = Range(10, 20, "(]")
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 in r
+ assert 19 in r
+ assert 20 in r
+ assert 21 not in r
+
+ r = Range(20, 10)
+ assert 9 not in r
+ assert 10 not in r
+ assert 11 not in r
+ assert 19 not in r
+ assert 20 not in r
+ assert 21 not in r
+
+ def test_nonzero(self):
+ assert Range()
+ assert Range(10, 20)
+ assert not Range(empty=True)
+
+ def test_eq_hash(self):
+ def assert_equal(r1, r2):
+ assert r1 == r2
+ assert hash(r1) == hash(r2)
+
+ assert_equal(Range(empty=True), Range(empty=True))
+ assert_equal(Range(), Range())
+ assert_equal(Range(10, None), Range(10, None))
+ assert_equal(Range(10, 20), Range(10, 20))
+ assert_equal(Range(10, 20), Range(10, 20, "[)"))
+ assert_equal(Range(10, 20, "[]"), Range(10, 20, "[]"))
+
+ def assert_not_equal(r1, r2):
+ assert r1 != r2
+ assert hash(r1) != hash(r2)
+
+ assert_not_equal(Range(10, 20), Range(10, 21))
+ assert_not_equal(Range(10, 20), Range(11, 20))
+ assert_not_equal(Range(10, 20, "[)"), Range(10, 20, "[]"))
+
+ def test_eq_wrong_type(self):
+ assert Range(10, 20) != ()
+
+ # as the postgres docs describe for the server-side stuff,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def test_lt_ordering(self):
+ assert Range(empty=True) < Range(0, 4)
+ assert not Range(1, 2) < Range(0, 4)
+ assert Range(0, 4) < Range(1, 2)
+ assert not Range(1, 2) < Range()
+ assert Range() < Range(1, 2)
+ assert not Range(1) < Range(upper=1)
+ assert not Range() < Range()
+ assert not Range(empty=True) < Range(empty=True)
+ assert not Range(1, 2) < Range(1, 2)
+ with pytest.raises(TypeError):
+ assert 1 < Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not Range(1, 2) < 1
+
+ def test_gt_ordering(self):
+ assert not Range(empty=True) > Range(0, 4)
+ assert Range(1, 2) > Range(0, 4)
+ assert not Range(0, 4) > Range(1, 2)
+ assert Range(1, 2) > Range()
+ assert not Range() > Range(1, 2)
+ assert Range(1) > Range(upper=1)
+ assert not Range() > Range()
+ assert not Range(empty=True) > Range(empty=True)
+ assert not Range(1, 2) > Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not 1 > Range(1, 2)
+ with pytest.raises(TypeError):
+ assert Range(1, 2) > 1
+
+ def test_le_ordering(self):
+ assert Range(empty=True) <= Range(0, 4)
+ assert not Range(1, 2) <= Range(0, 4)
+ assert Range(0, 4) <= Range(1, 2)
+ assert not Range(1, 2) <= Range()
+ assert Range() <= Range(1, 2)
+ assert not Range(1) <= Range(upper=1)
+ assert Range() <= Range()
+ assert Range(empty=True) <= Range(empty=True)
+ assert Range(1, 2) <= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert 1 <= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not Range(1, 2) <= 1
+
+ def test_ge_ordering(self):
+ assert not Range(empty=True) >= Range(0, 4)
+ assert Range(1, 2) >= Range(0, 4)
+ assert not Range(0, 4) >= Range(1, 2)
+ assert Range(1, 2) >= Range()
+ assert not Range() >= Range(1, 2)
+ assert Range(1) >= Range(upper=1)
+ assert Range() >= Range()
+ assert Range(empty=True) >= Range(empty=True)
+ assert Range(1, 2) >= Range(1, 2)
+ with pytest.raises(TypeError):
+ assert not 1 >= Range(1, 2)
+ with pytest.raises(TypeError):
+ (Range(1, 2) >= 1)
+
+ def test_pickling(self):
+ r = Range(0, 4)
+ assert pickle.loads(pickle.dumps(r)) == r
+
+ def test_str(self):
+ """
+ Range types should have a short and readable ``str`` implementation.
+ """
+ expected = [
+ "(0, 4)",
+ "[0, 4]",
+ "(0, 4]",
+ "[0, 4)",
+ "empty",
+ ]
+ results = []
+
+ for bounds in ("()", "[]", "(]", "[)"):
+ r = Range(0, 4, bounds=bounds)
+ results.append(str(r))
+
+ r = Range(empty=True)
+ results.append(str(r))
+ assert results == expected
+
+ def test_str_datetime(self):
+ """
+ Date-Time ranges should return a human-readable string as well on
+ string conversion.
+ """
+ tz = dt.timezone(dt.timedelta(hours=-5))
+ r = Range(
+ dt.datetime(2010, 1, 1, tzinfo=tz),
+ dt.datetime(2011, 1, 1, tzinfo=tz),
+ )
+ expected = "[2010-01-01 00:00:00-05:00, 2011-01-01 00:00:00-05:00)"
+ result = str(r)
+ assert result == expected
+
+ def test_exclude_inf_bounds(self):
+ r = Range(None, 10, "[]")
+ assert r.lower is None
+ assert not r.lower_inc
+ assert r.bounds == "(]"
+
+ r = Range(10, None, "[]")
+ assert r.upper is None
+ assert not r.upper_inc
+ assert r.bounds == "[)"
+
+ r = Range(None, None, "[]")
+ assert r.lower is None
+ assert not r.lower_inc
+ assert r.upper is None
+ assert not r.upper_inc
+ assert r.bounds == "()"
+
+
+def test_no_info_error(conn):
+ with pytest.raises(TypeError, match="range"):
+ register_range(None, conn) # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+def test_literal_invalid_name(conn, name):
+ conn.execute("set client_encoding to utf8")
+ conn.execute(f'create type "{name}" as range (subtype = text)')
+ info = RangeInfo.fetch(conn, f'"{name}"')
+ register_range(info, conn)
+ obj = Range("a", "z", "[]")
+ assert sql.Literal(obj).as_string(conn) == f"'[a,z]'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format(obj))
+ assert cur.fetchone()[0] == obj
diff --git a/tests/types/test_shapely.py b/tests/types/test_shapely.py
new file mode 100644
index 0000000..0f7007e
--- /dev/null
+++ b/tests/types/test_shapely.py
@@ -0,0 +1,152 @@
+import pytest
+
+import psycopg
+from psycopg.pq import Format
+from psycopg.types import TypeInfo
+from psycopg.adapt import PyFormat
+
+pytest.importorskip("shapely")
+
+from shapely.geometry import Point, Polygon, MultiPolygon # noqa: E402
+from psycopg.types.shapely import register_shapely # noqa: E402
+
+pytestmark = [
+ pytest.mark.postgis,
+ pytest.mark.crdb("skip"),
+]
+
+# real example, with CRS and "holes"
+MULTIPOLYGON_GEOJSON = """
+{
+ "type":"MultiPolygon",
+ "crs":{
+ "type":"name",
+ "properties":{
+ "name":"EPSG:3857"
+ }
+ },
+ "coordinates":[
+ [
+ [
+ [89574.61111389, 6894228.638802719],
+ [89576.815239808, 6894208.60747024],
+ [89576.904295401, 6894207.820852726],
+ [89577.99522641, 6894208.022080451],
+ [89577.961830563, 6894209.229446936],
+ [89589.227363031, 6894210.601454523],
+ [89594.615226386, 6894161.849595264],
+ [89600.314784314, 6894111.37846976],
+ [89651.187791607, 6894116.774968589],
+ [89648.49385993, 6894140.226914071],
+ [89642.92788539, 6894193.423936413],
+ [89639.721884055, 6894224.08372821],
+ [89589.283022777, 6894218.431048969],
+ [89588.192091767, 6894230.248628867],
+ [89574.61111389, 6894228.638802719]
+ ],
+ [
+ [89610.344670435, 6894182.466199101],
+ [89625.985058891, 6894184.258949757],
+ [89629.547282597, 6894153.270030369],
+ [89613.918026089, 6894151.458993318],
+ [89610.344670435, 6894182.466199101]
+ ]
+ ]
+ ]
+}"""
+
+SAMPLE_POINT_GEOJSON = '{"type":"Point","coordinates":[1.2, 3.4]}'
+
+
+@pytest.fixture
+def shapely_conn(conn, svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists postgis")
+ except psycopg.Error as e:
+ pytest.skip(f"can't create extension postgis: {e}")
+
+ info = TypeInfo.fetch(conn, "geometry")
+ assert info
+ register_shapely(info, conn)
+ return conn
+
+
+def test_no_adapter(conn):
+ point = Point(1.2, 3.4)
+ with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"):
+ conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0]
+
+
+def test_no_info_error(conn):
+ from psycopg.types.shapely import register_shapely
+
+ with pytest.raises(TypeError, match="postgis.*extension"):
+ register_shapely(None, conn) # type: ignore[arg-type]
+
+
+def test_with_adapter(shapely_conn):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
+
+ assert (
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0]
+ == "geometry"
+ )
+
+ assert (
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0]
+ == "geometry"
+ )
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", Format)
+def test_write_read_shape(shapely_conn, fmt_in, fmt_out):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
+
+ with shapely_conn.cursor(binary=fmt_out) as cur:
+ cur.execute(
+ """
+ create table sample_geoms(
+ id INTEGER PRIMARY KEY,
+ geom geometry
+ )
+ """
+ )
+ cur.execute(
+ f"insert into sample_geoms(id, geom) VALUES(1, %{fmt_in})",
+ (SAMPLE_POINT,),
+ )
+ cur.execute(
+ f"insert into sample_geoms(id, geom) VALUES(2, %{fmt_in})",
+ (SAMPLE_POLYGON,),
+ )
+
+ cur.execute("select geom from sample_geoms where id=1")
+ result = cur.fetchone()[0]
+ assert result == SAMPLE_POINT
+
+ cur.execute("select geom from sample_geoms where id=2")
+ result = cur.fetchone()[0]
+ assert result == SAMPLE_POLYGON
+
+
+@pytest.mark.parametrize("fmt_out", Format)
+def test_match_geojson(shapely_conn, fmt_out):
+ SAMPLE_POINT = Point(1.2, 3.4)
+ with shapely_conn.cursor(binary=fmt_out) as cur:
+ cur.execute(
+ """
+ select ST_GeomFromGeoJSON(%s)
+ """,
+ (SAMPLE_POINT_GEOJSON,),
+ )
+ result = cur.fetchone()[0]
+ # clone the coordinates to have a list instead of a shapely wrapper
+ assert result.coords[:] == SAMPLE_POINT.coords[:]
+ #
+ cur.execute("select ST_GeomFromGeoJSON(%s)", (MULTIPOLYGON_GEOJSON,))
+ result = cur.fetchone()[0]
+ assert isinstance(result, MultiPolygon)
diff --git a/tests/types/test_string.py b/tests/types/test_string.py
new file mode 100644
index 0000000..d23e5e0
--- /dev/null
+++ b/tests/types/test_string.py
@@ -0,0 +1,307 @@
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.adapt import PyFormat
+from psycopg import Binary
+
+from ..utils import eur
+from ..fix_crdb import crdb_encoding, crdb_scs_off
+
+#
+# tests with text
+#
+
+
+def crdb_bpchar(*args):
+ return pytest.param(*args, marks=pytest.mark.crdb("skip", reason="bpchar"))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_1char(conn, fmt_in):
+ cur = conn.cursor()
+ for i in range(1, 256):
+ cur.execute(f"select %{fmt_in.value} = chr(%s)", (chr(i), i))
+ assert cur.fetchone()[0] is True, chr(i)
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+def test_quote_1char(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
+
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = chr(%s)")
+ for i in range(1, 256):
+ if chr(i) == "%":
+ continue
+ cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
+ assert cur.fetchone()[0] is True, chr(i)
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+@pytest.mark.crdb("skip", reason="can deal with 0 strings")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_zero(conn, fmt_in):
+ cur = conn.cursor()
+ s = "foo\x00bar"
+ with pytest.raises(psycopg.DataError):
+ cur.execute(f"select %{fmt_in.value}::text", (s,))
+
+
+def test_quote_zero(conn):
+ cur = conn.cursor()
+ s = "foo\x00bar"
+ with pytest.raises(psycopg.DataError):
+ cur.execute(sql.SQL("select {}").format(sql.Literal(s)))
+
+
+# the only way to make this pass is to reduce %% -> % every time
+# not only when there are query arguments
+# see https://github.com/psycopg/psycopg2/issues/825
+@pytest.mark.xfail
+def test_quote_percent(conn):
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%")))
+ assert cur.fetchone()[0] == "%"
+
+ cur.execute(
+ sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")),
+ (ord("%"),),
+ )
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "typename", ["text", "varchar", "name", crdb_bpchar("bpchar"), '"char"']
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_1char(conn, typename, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(1, 256):
+ if typename == '"char"' and i > 127:
+ # for char > 128 the client receives only 194 or 195.
+ continue
+
+ cur.execute(f"select chr(%s)::{typename}", (i,))
+ res = cur.fetchone()[0]
+ assert res == chr(i)
+
+ assert cur.pgresult.fformat(0) == fmt_out
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize(
+ "encoding", ["utf8", crdb_encoding("latin9"), crdb_encoding("sql_ascii")]
+)
+def test_dump_enc(conn, fmt_in, encoding):
+ cur = conn.cursor()
+
+ conn.execute(f"set client_encoding to {encoding}")
+ (res,) = cur.execute(f"select ascii(%{fmt_in.value})", (eur,)).fetchone()
+ assert res == ord(eur)
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_badenc(conn, fmt_in):
+ cur = conn.cursor()
+
+ conn.execute("set client_encoding to latin1")
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute(f"select %{fmt_in.value}::bytea", (eur,))
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_utf8_badenc(conn, fmt_in):
+ cur = conn.cursor()
+
+ conn.execute("set client_encoding to utf8")
+ with pytest.raises(UnicodeEncodeError):
+ cur.execute(f"select %{fmt_in.value}", ("\uddf8",))
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_enum(conn, fmt_in):
+ from enum import Enum
+
+ class MyEnum(str, Enum):
+ foo = "foo"
+ bar = "bar"
+
+ cur = conn.cursor()
+ cur.execute("create type myenum as enum ('foo', 'bar')")
+ cur.execute("create table with_enum (e myenum)")
+ cur.execute(f"insert into with_enum (e) values (%{fmt_in.value})", (MyEnum.foo,))
+ (res,) = cur.execute("select e from with_enum").fetchone()
+ assert res == "foo"
+
+
+@pytest.mark.crdb("skip")
+@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT])
+def test_dump_text_oid(conn, fmt_in):
+ conn.autocommit = True
+
+ with pytest.raises(e.IndeterminateDatatype):
+ conn.execute(f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"])
+ conn.adapters.register_dumper(str, psycopg.types.string.StrDumper)
+ cur = conn.execute(
+ f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"]
+ )
+ assert cur.fetchone()[0] == "foobar"
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_enc(conn, typename, encoding, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute(f"set client_encoding to {encoding}")
+ (res,) = cur.execute(f"select chr(%s)::{typename}", [ord(eur)]).fetchone()
+ assert res == eur
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
+ assert res == eur
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_badenc(conn, typename, fmt_out):
+ conn.autocommit = True
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute("set client_encoding to latin1")
+ with pytest.raises(psycopg.DataError):
+ cur.execute(f"select chr(%s)::{typename}", [ord(eur)])
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ with pytest.raises(psycopg.DataError):
+ copy.read_row()
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
+def test_load_ascii(conn, typename, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+
+ conn.execute("set client_encoding to sql_ascii")
+ cur.execute(f"select chr(%s)::{typename}", [ord(eur)])
+ assert cur.fetchone()[0] == eur.encode()
+
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
+ assert res == eur.encode()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", crdb_bpchar("bpchar")])
+def test_text_array(conn, typename, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ a = list(map(chr, range(1, 256))) + [eur]
+
+ (res,) = cur.execute(f"select %{fmt_in.value}::{typename}[]", (a,)).fetchone()
+ assert res == a
+
+
+@pytest.mark.crdb_skip("encoding")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_text_array_ascii(conn, fmt_in, fmt_out):
+ conn.execute("set client_encoding to sql_ascii")
+ cur = conn.cursor(binary=fmt_out)
+ a = list(map(chr, range(1, 256))) + [eur]
+ exp = [s.encode() for s in a]
+ (res,) = cur.execute(f"select %{fmt_in.value}::text[]", (a,)).fetchone()
+ assert res == exp
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name"])
+def test_oid_lookup(conn, typename, fmt_out):
+ dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out)
+ assert dumper.oid == conn.adapters.types[typename].oid
+ assert dumper.format == fmt_out
+
+
+#
+# tests with bytea
+#
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_dump_1byte(conn, fmt_in, pytype):
+ cur = conn.cursor()
+ for i in range(0, 256):
+ obj = pytype(bytes([i]))
+ cur.execute(f"select %{fmt_in.value} = set_byte('x', 0, %s)", (obj, i))
+ assert cur.fetchone()[0] is True, i
+
+ cur.execute(f"select %{fmt_in.value} = array[set_byte('x', 0, %s)]", ([obj], i))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")])
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_quote_1byte(conn, scs, pytype):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
+
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = set_byte('x', 0, %s)")
+ for i in range(0, 256):
+ obj = pytype(bytes([i]))
+ cur.execute(query.format(ch=sql.Literal(obj)), (i,))
+ assert cur.fetchone()[0] is True, i
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_1byte(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for i in range(0, 256):
+ cur.execute("select set_byte('x', 0, %s)", (i,))
+ val = cur.fetchone()[0]
+ assert val == bytes([i])
+
+ assert isinstance(val, bytes)
+ assert cur.pgresult.fformat(0) == fmt_out
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_bytea_array(conn, fmt_in, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ a = [bytes(range(0, 256))]
+ (res,) = cur.execute(f"select %{fmt_in.value}::bytea[]", (a,)).fetchone()
+ assert res == a
diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py
new file mode 100644
index 0000000..f86f066
--- /dev/null
+++ b/tests/types/test_uuid.py
@@ -0,0 +1,56 @@
+import sys
+from uuid import UUID
+import subprocess as sp
+
+import pytest
+
+from psycopg import pq
+from psycopg import sql
+from psycopg.adapt import PyFormat
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_uuid_dump(conn, fmt_in):
+ val = "12345678123456781234567812345679"
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in.value} = %s::uuid", (UUID(val), val))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.crdb_skip("copy")
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_uuid_load(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ val = "12345678123456781234567812345679"
+ cur.execute("select %s::uuid", (val,))
+ assert cur.fetchone()[0] == UUID(val)
+
+ stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["uuid"])
+ (res,) = copy.read_row()
+
+ assert res == UUID(val)
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+def test_lazy_load(dsn):
+ script = f"""\
+import sys
+import psycopg
+
+assert 'uuid' not in sys.modules
+
+conn = psycopg.connect({dsn!r})
+with conn.cursor() as cur:
+ cur.execute("select repeat('1', 32)::uuid")
+ cur.fetchone()
+
+conn.close()
+assert 'uuid' in sys.modules
+"""
+
+ sp.check_call([sys.executable, "-c", script])
diff --git a/tests/typing_example.py b/tests/typing_example.py
new file mode 100644
index 0000000..a26ca49
--- /dev/null
+++ b/tests/typing_example.py
@@ -0,0 +1,176 @@
+# flake8: builtins=reveal_type
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from psycopg import Connection, Cursor, ServerCursor, connect, rows
+from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor
+
+
+def int_row_factory(
+ cursor: Union[Cursor[Any], AsyncCursor[Any]]
+) -> Callable[[Sequence[int]], int]:
+ return lambda values: values[0] if values else 42
+
+
+@dataclass
+class Person:
+ name: str
+ address: str
+
+ @classmethod
+ def row_factory(
+ cls, cursor: Union[Cursor[Any], AsyncCursor[Any]]
+ ) -> Callable[[Sequence[str]], Person]:
+ def mkrow(values: Sequence[str]) -> Person:
+ name, address = values
+ return cls(name, address)
+
+ return mkrow
+
+
+def kwargsf(*, foo: int, bar: int, baz: int) -> int:
+ return 42
+
+
+def argsf(foo: int, bar: int, baz: int) -> float:
+ return 42.0
+
+
+def check_row_factory_cursor() -> None:
+ """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
+ conn = connect()
+
+ cur1: Cursor[Any]
+ cur1 = conn.cursor()
+ r1: Optional[Any]
+ r1 = cur1.fetchone()
+ r1 is not None
+
+ cur2: Cursor[int]
+ r2: Optional[int]
+ with conn.cursor(row_factory=int_row_factory) as cur2:
+ cur2.execute("select 1")
+ r2 = cur2.fetchone()
+ r2 and r2 > 0
+
+ cur3: ServerCursor[Person]
+ persons: Sequence[Person]
+ with conn.cursor(name="s", row_factory=Person.row_factory) as cur3:
+ cur3.execute("select * from persons where name like 'al%'")
+ persons = cur3.fetchall()
+ persons[0].address
+
+
+async def async_check_row_factory_cursor() -> None:
+ """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
+ conn = await AsyncConnection.connect()
+
+ cur1: AsyncCursor[Any]
+ cur1 = conn.cursor()
+ r1: Optional[Any]
+ r1 = await cur1.fetchone()
+ r1 is not None
+
+ cur2: AsyncCursor[int]
+ r2: Optional[int]
+ async with conn.cursor(row_factory=int_row_factory) as cur2:
+ await cur2.execute("select 1")
+ r2 = await cur2.fetchone()
+ r2 and r2 > 0
+
+ cur3: AsyncServerCursor[Person]
+ persons: Sequence[Person]
+ async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3:
+ await cur3.execute("select * from persons where name like 'al%'")
+ persons = await cur3.fetchall()
+ persons[0].address
+
+
+def check_row_factory_connection() -> None:
+ """Type-check connect(..., row_factory=<MyRowFactory>) or
+ Connection.row_factory cases.
+ """
+ conn1: Connection[int]
+ cur1: Cursor[int]
+ r1: Optional[int]
+ conn1 = connect(row_factory=int_row_factory)
+ cur1 = conn1.execute("select 1")
+ r1 = cur1.fetchone()
+ r1 != 0
+ with conn1.cursor() as cur1:
+ cur1.execute("select 2")
+
+ conn2: Connection[Person]
+ cur2: Cursor[Person]
+ r2: Optional[Person]
+ conn2 = connect(row_factory=Person.row_factory)
+ cur2 = conn2.execute("select * from persons")
+ r2 = cur2.fetchone()
+ r2 and r2.name
+ with conn2.cursor() as cur2:
+ cur2.execute("select 2")
+
+ cur3: Cursor[Tuple[Any, ...]]
+ r3: Optional[Tuple[Any, ...]]
+ conn3 = connect()
+ cur3 = conn3.execute("select 3")
+ with conn3.cursor() as cur3:
+ cur3.execute("select 42")
+ r3 = cur3.fetchone()
+ r3 and len(r3)
+
+
+async def async_check_row_factory_connection() -> None:
+ """Type-check connect(..., row_factory=<MyRowFactory>) or
+ Connection.row_factory cases.
+ """
+ conn1: AsyncConnection[int]
+ cur1: AsyncCursor[int]
+ r1: Optional[int]
+ conn1 = await AsyncConnection.connect(row_factory=int_row_factory)
+ cur1 = await conn1.execute("select 1")
+ r1 = await cur1.fetchone()
+ r1 != 0
+ async with conn1.cursor() as cur1:
+ await cur1.execute("select 2")
+
+ conn2: AsyncConnection[Person]
+ cur2: AsyncCursor[Person]
+ r2: Optional[Person]
+ conn2 = await AsyncConnection.connect(row_factory=Person.row_factory)
+ cur2 = await conn2.execute("select * from persons")
+ r2 = await cur2.fetchone()
+ r2 and r2.name
+ async with conn2.cursor() as cur2:
+ await cur2.execute("select 2")
+
+ cur3: AsyncCursor[Tuple[Any, ...]]
+ r3: Optional[Tuple[Any, ...]]
+ conn3 = await AsyncConnection.connect()
+ cur3 = await conn3.execute("select 3")
+ async with conn3.cursor() as cur3:
+ await cur3.execute("select 42")
+ r3 = await cur3.fetchone()
+ r3 and len(r3)
+
+
+def check_row_factories() -> None:
+ conn1 = connect(row_factory=rows.tuple_row)
+ v1: Tuple[Any, ...] = conn1.execute("").fetchall()[0]
+
+ conn2 = connect(row_factory=rows.dict_row)
+ v2: Dict[str, Any] = conn2.execute("").fetchall()[0]
+
+ conn3 = connect(row_factory=rows.class_row(Person))
+ v3: Person = conn3.execute("").fetchall()[0]
+
+ conn4 = connect(row_factory=rows.args_row(argsf))
+ v4: float = conn4.execute("").fetchall()[0]
+
+ conn5 = connect(row_factory=rows.kwargs_row(kwargsf))
+ v5: int = conn5.execute("").fetchall()[0]
+
+ v1, v2, v3, v4, v5
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..871f65d
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,179 @@
+import gc
+import re
+import sys
+import operator
+from typing import Callable, Optional, Tuple
+
+import pytest
+
+eur = "\u20ac"
+
+
+def check_libpq_version(got, want):
+ """
+ Verify if the libpq version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.libpq(">= 12")
+
+ and skips the test if the requested version doesn't match what's loaded.
+ """
+ return check_version(got, want, "libpq", postgres_rule=True)
+
+
+def check_postgres_version(got, want):
+ """
+ Verify if the server version is a version accepted.
+
+ This function is called on the tests marked with something like::
+
+ @pytest.mark.pg(">= 12")
+
+ and skips the test if the server version doesn't match what expected.
+ """
+ return check_version(got, want, "PostgreSQL", postgres_rule=True)
+
+
+def check_version(got, want, whose_version, postgres_rule=True):
+ pred = VersionCheck.parse(want, postgres_rule=postgres_rule)
+ pred.whose = whose_version
+ return pred.get_skip_message(got)
+
+
+class VersionCheck:
+ """
+ Helper to compare a version number with a test spec.
+ """
+
+ def __init__(
+ self,
+ *,
+ skip: bool = False,
+ op: Optional[str] = None,
+ version_tuple: Tuple[int, ...] = (),
+ whose: str = "(wanted)",
+ postgres_rule: bool = False,
+ ):
+ self.skip = skip
+ self.op = op or "=="
+ self.version_tuple = version_tuple
+ self.whose = whose
+ # Treat 10.1 as 10.0.1
+ self.postgres_rule = postgres_rule
+
+ @classmethod
+ def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck":
+ # Parse a spec like "> 9.6", "skip < 21.2.0"
+ m = re.match(
+ r"""(?ix)
+ ^\s* (skip|only)?
+ \s* (==|!=|>=|<=|>|<)?
+ \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?
+ \s* $
+ """,
+ spec,
+ )
+ if m is None:
+ pytest.fail(f"bad wanted version spec: {spec}")
+
+ skip = (m.group(1) or "only").lower() == "skip"
+ op = m.group(2)
+ version_tuple = tuple(int(n) for n in m.groups()[2:] if n)
+
+ return cls(
+ skip=skip, op=op, version_tuple=version_tuple, postgres_rule=postgres_rule
+ )
+
+ def get_skip_message(self, version: Optional[int]) -> Optional[str]:
+ got_tuple = self._parse_int_version(version)
+
+ msg: Optional[str] = None
+ if self.skip:
+ if got_tuple:
+ if not self.version_tuple:
+ msg = f"skip on {self.whose}"
+ elif self._match_version(got_tuple):
+ msg = (
+ f"skip on {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ if not got_tuple:
+ msg = f"only for {self.whose}"
+ elif not self._match_version(got_tuple):
+ if self.version_tuple:
+ msg = (
+ f"only for {self.whose} {self.op}"
+ f" {'.'.join(map(str, self.version_tuple))}"
+ )
+ else:
+ msg = f"only for {self.whose}"
+
+ return msg
+
+ _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq", "!=": "ne"}
+
+ def _match_version(self, got_tuple: Tuple[int, ...]) -> bool:
+ if not self.version_tuple:
+ return True
+
+ version_tuple = self.version_tuple
+ if self.postgres_rule and version_tuple and version_tuple[0] >= 10:
+ assert len(version_tuple) <= 2
+ version_tuple = version_tuple[:1] + (0,) + version_tuple[1:]
+
+ op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool]
+ op = getattr(operator, self._OP_NAMES[self.op])
+ return op(got_tuple, version_tuple)
+
+ def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]:
+ if version is None:
+ return ()
+ version, ver_fix = divmod(version, 100)
+ ver_maj, ver_min = divmod(version, 100)
+ return (ver_maj, ver_min, ver_fix)
+
+
+def gc_collect():
+ """
+ gc.collect(), but more insisting.
+ """
+ for i in range(3):
+ gc.collect()
+
+
+NO_COUNT_TYPES: Tuple[type, ...] = ()
+
+if sys.version_info[:2] == (3, 10):
+ # On my laptop there are occasional creations of a single one of these objects
+ # with empty content, which might be some Decimal caching.
+ # Keeping the guard as strict as possible, to be extended if other types
+ # or versions are necessary.
+ try:
+ from _contextvars import Context # type: ignore
+ except ImportError:
+ pass
+ else:
+ NO_COUNT_TYPES += (Context,)
+
+
+def gc_count() -> int:
+ """
+ len(gc.get_objects()), with subtleties.
+ """
+ if not NO_COUNT_TYPES:
+ return len(gc.get_objects())
+
+ # Note: not using a list comprehension because it pollutes the objects list.
+ rv = 0
+ for obj in gc.get_objects():
+ if isinstance(obj, NO_COUNT_TYPES):
+ continue
+ rv += 1
+
+ return rv
+
+
+async def alist(it):
+ return [i async for i in it]