diff options
Diffstat (limited to 'tests')
103 files changed, 32033 insertions, 0 deletions
diff --git a/tests/README.rst b/tests/README.rst new file mode 100644 index 0000000..63c7238 --- /dev/null +++ b/tests/README.rst @@ -0,0 +1,94 @@ +psycopg test suite +=================== + +Quick version +------------- + +To run tests on the current code you can install the `test` extra of the +package, specify a connection string in the `PSYCOPG_TEST_DSN` env var to +connect to a test database, and run ``pytest``:: + + $ pip install -e "psycopg[test]" + $ export PSYCOPG_TEST_DSN="host=localhost dbname=psycopg_test" + $ pytest + + +Test options +------------ + +- The tests output header shows additional psycopg related information, + on top of the one normally displayed by ``pytest`` and the extensions used:: + + $ pytest + ========================= test session starts ========================= + platform linux -- Python 3.8.5, pytest-6.0.2, py-1.10.0, pluggy-0.13.1 + Using --randomly-seed=2416596601 + libpq available: 130002 + libpq wrapper implementation: c + + +- By default the tests run using the ``pq`` implementation that psycopg would + choose (the C module if installed, else the Python one). In order to test a + different implementation, use the normal `pq module selection mechanism`__ + of the ``PSYCOPG_IMPL`` env var:: + + $ PSYCOPG_IMPL=python pytest + ========================= test session starts ========================= + [...] + libpq available: 130002 + libpq wrapper implementation: python + + .. __: https://www.psycopg.org/psycopg/docs/api/pq.html#pq-module-implementations + + +- Slow tests have a ``slow`` marker which can be selected to reduce test + runtime to a few seconds only. Please add a ``@pytest.mark.slow`` marker to + any test needing an arbitrary wait. At the time of writing:: + + $ pytest + ========================= test session starts ========================= + [...] + ======= 1983 passed, 3 skipped, 110 xfailed in 78.21s (0:01:18) ======= + + $ pytest -m "not slow" + ========================= test session starts ========================= + [...] + ==== 1877 passed, 2 skipped, 169 deselected, 48 xfailed in 13.47s ===== + +- ``pytest`` option ``--pq-trace={TRACEFILE,STDERR}`` can be used to capture + libpq trace. When using ``stderr``, the output will only be shown for + failing or in-error tests, unless ``-s/--capture=no`` option is used. + +- ``pytest`` option ``--pq-debug`` can be used to log access to libpq's + ``PGconn`` functions. + + +Testing in docker +----------------- + +Useful to test different Python versions without installing them. Can be used +to replicate GitHub actions failures, specifying the ``--randomly-seed`` used +in the test run. The following ``PG*`` env vars are an example to adjust the +test dsn in order to connect to a database running on the docker host: specify +a set of env vars working for your setup:: + + $ docker run -ti --rm --volume `pwd`:/src --workdir /src \ + -e PSYCOPG_TEST_DSN -e PGHOST=172.17.0.1 -e PGUSER=`whoami` \ + python:3.7 bash + + # pip install -e "./psycopg[test]" ./psycopg_pool ./psycopg_c + # pytest + + +Testing with CockroachDB +======================== + +You can run CRDB in a docker container using:: + + docker run -p 26257:26257 --name crdb --rm \ + cockroachdb/cockroach:v22.1.3 start-single-node --insecure + +And use the following connection string to run the tests:: + + export PSYCOPG_TEST_DSN="host=localhost port=26257 user=root dbname=defaultdb" + pytest ... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/adapters_example.py b/tests/adapters_example.py new file mode 100644 index 0000000..a184e6a --- /dev/null +++ b/tests/adapters_example.py @@ -0,0 +1,45 @@ +from typing import Optional + +from psycopg import pq +from psycopg.abc import Dumper, Loader, AdaptContext, PyFormat, Buffer + + +def f() -> None: + d: Dumper = MyStrDumper(str, None) + assert d.dump("abc") == b"abcabc" + assert d.quote("abc") == b"'abcabc'" + + lo: Loader = MyTextLoader(0, None) + assert lo.load(b"abc") == "abcabc" + + +class MyStrDumper: + format = pq.Format.TEXT + oid = 25 # text + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + self._cls = cls + + def dump(self, obj: str) -> bytes: + return (obj * 2).encode() + + def quote(self, obj: str) -> bytes: + value = self.dump(obj) + esc = pq.Escaping() + return b"'%s'" % esc.escape_string(value.replace(b"h", b"q")) + + def get_key(self, obj: str, format: PyFormat) -> type: + return self._cls + + def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper": + return self + + +class MyTextLoader: + format = pq.Format.TEXT + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + pass + + def load(self, data: Buffer) -> str: + return (bytes(data) * 2).decode() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..15bcf40 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import sys +import asyncio +import selectors +from typing import List + +pytest_plugins = ( + "tests.fix_db", + "tests.fix_pq", + "tests.fix_mypy", + "tests.fix_faker", + "tests.fix_proxy", + "tests.fix_psycopg", + "tests.fix_crdb", + "tests.pool.fix_pool", +) + + +def pytest_configure(config): + markers = [ + "slow: this test is kinda slow (skip with -m 'not slow')", + "flakey(reason): this test may fail unpredictably')", + # There are troubles on travis with these kind of tests and I cannot + # catch the exception for my life. + "subprocess: the test import psycopg after subprocess", + "timing: the test is timing based and can fail on cheese hardware", + "dns: the test requires dnspython to run", + "postgis: the test requires the PostGIS extension to run", + ] + + for marker in markers: + config.addinivalue_line("markers", marker) + + +def pytest_addoption(parser): + parser.addoption( + "--loop", + choices=["default", "uvloop"], + default="default", + help="The asyncio loop to use for async tests.", + ) + + +def pytest_report_header(config): + rv = [] + + rv.append(f"default selector: {selectors.DefaultSelector.__name__}") + loop = config.getoption("--loop") + if loop != "default": + rv.append(f"asyncio loop: {loop}") + + return rv + + +def pytest_sessionstart(session): + # Detect if there was a segfault in the previous run. + # + # In case of segfault, pytest doesn't get a chance to write failed tests + # in the cache. As a consequence, retries would find no test failed and + # assume that all tests passed in the previous run, making the whole test pass. + cache = session.config.cache + if cache.get("segfault", False): + session.warn(Warning("Previous run resulted in segfault! Not running any test")) + session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)")) + raise session.Failed + cache.set("segfault", True) + + # Configure the async loop. + loop = session.config.getoption("--loop") + if loop == "uvloop": + import uvloop + + uvloop.install() + else: + assert loop == "default" + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + +allow_fail_messages: List[str] = [] + + +def pytest_sessionfinish(session, exitstatus): + # Mark the test run successful (in the sense -weak- that we didn't segfault). + session.config.cache.set("segfault", False) + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + if allow_fail_messages: + terminalreporter.section("failed tests ignored") + for msg in allow_fail_messages: + terminalreporter.line(msg) diff --git a/tests/constraints.txt b/tests/constraints.txt new file mode 100644 index 0000000..ef03ba1 --- /dev/null +++ b/tests/constraints.txt @@ -0,0 +1,32 @@ +# This is a constraint file forcing the minimum allowed version to be +# installed. +# +# https://pip.pypa.io/en/stable/user_guide/#constraints-files + +# From install_requires +backports.zoneinfo == 0.2.0 +typing-extensions == 4.1.0 + +# From the 'test' extra +mypy == 0.981 +pproxy == 2.7.0 +pytest == 6.2.5 +pytest-asyncio == 0.17.0 +pytest-cov == 3.0.0 +pytest-randomly == 3.10.0 + +# From the 'dev' extra +black == 22.3.0 +dnspython == 2.1.0 +flake8 == 4.0.0 +mypy == 0.981 +types-setuptools == 57.4.0 +wheel == 0.37 + +# From the 'docs' extra +Sphinx == 4.2.0 +furo == 2021.11.23 +sphinx-autobuild == 2021.3.14 +sphinx-autodoc-typehints == 1.12.0 +dnspython == 2.1.0 +shapely == 1.7.0 diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/crdb/__init__.py diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py new file mode 100644 index 0000000..ce5bacf --- /dev/null +++ b/tests/crdb/test_adapt.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +import pytest + +from psycopg.crdb import adapters, CrdbConnection + +from psycopg.adapt import PyFormat, Transformer +from psycopg.types.array import ListDumper +from psycopg.postgres import types as builtins + +from ..test_adapt import MyStr, make_dumper, make_bin_dumper +from ..test_adapt import make_loader, make_bin_loader + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_return_untyped(conn, fmt_in): + # Analyze and check for changes using strings in untyped/typed contexts + cur = conn.cursor() + # Currently string are passed as text oid to CockroachDB, unlike Postgres, + # to which strings are passed as unknown. This is because CRDB doesn't + # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises + # an error. However, unlike PostgreSQL, text can be cast to any other type. + cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10]) + assert cur.fetchone() == ("hello", 10) + + cur.execute("create table testjson(data jsonb)") + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + assert cur.execute("select data from testjson").fetchone() == ({},) + + +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + +@pytest.fixture +def crdb_adapters(): + """Restore the crdb adapters after a test has changed them.""" + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +def test_dump_global_ctx(dsn, crdb_adapters, pgconn): + adapters.register_dumper(MyStr, make_bin_dumper("gb")) + adapters.register_dumper(MyStr, make_dumper("gt")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_load_global_ctx(dsn, crdb_adapters): + adapters.register_loader("text", make_loader("gt")) + adapters.register_loader("text", make_bin_loader("gb")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py new file mode 100644 index 0000000..b2a69ef --- /dev/null +++ b/tests/crdb/test_connection.py @@ -0,0 +1,86 @@ +import time +import threading + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb + + +def test_is_crdb(conn): + assert CrdbConnection.is_crdb(conn) + assert CrdbConnection.is_crdb(conn.pgconn) + + +def test_connect(dsn): + with CrdbConnection.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + with psycopg.crdb.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + +def test_xid(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +def test_tpc_begin(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_begin("foo") + + +def test_tpc_recover(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_recover() + + +@pytest.mark.slow +def test_broken_connection(conn): + cur = conn.cursor() + (session_id,) = cur.execute("select session_id from [show session_id]").fetchone() + with pytest.raises(psycopg.DatabaseError): + cur.execute("cancel session %s", [session_id]) + assert conn.closed + + +@pytest.mark.slow +def test_broken(conn): + (session_id,) = conn.execute("show session_id").fetchone() + with pytest.raises(psycopg.OperationalError): + conn.execute("cancel session %s", [session_id]) + + assert conn.closed + assert conn.broken + conn.close() + assert conn.closed + assert conn.broken + + +@pytest.mark.slow +def test_identify_closure(conn_cls, dsn): + with conn_cls.connect(dsn, autocommit=True) as conn: + with conn_cls.connect(dsn, autocommit=True) as conn2: + (session_id,) = conn.execute("show session_id").fetchone() + + def closer(): + time.sleep(0.2) + conn2.execute("cancel session %s", [session_id]) + + t = threading.Thread(target=closer) + t.start() + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + # CRDB seems to take not less than 1s + assert 0.2 < dt < 2 + finally: + t.join() diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py new file mode 100644 index 0000000..b568e42 --- /dev/null +++ b/tests/crdb/test_connection_async.py @@ -0,0 +1,85 @@ +import time +import asyncio + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import AsyncCrdbConnection +from psycopg._compat import create_task + +import pytest + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +async def test_is_crdb(aconn): + assert AsyncCrdbConnection.is_crdb(aconn) + assert AsyncCrdbConnection.is_crdb(aconn.pgconn) + + +async def test_connect(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) + + +async def test_xid(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +async def test_tpc_begin(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_begin("foo") + + +async def test_tpc_recover(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_recover() + + +@pytest.mark.slow +async def test_broken_connection(aconn): + cur = aconn.cursor() + await cur.execute("select session_id from [show session_id]") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("cancel session %s", [session_id]) + assert aconn.closed + + +@pytest.mark.slow +async def test_broken(aconn): + cur = await aconn.execute("show session_id") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.OperationalError): + await aconn.execute("cancel session %s", [session_id]) + + assert aconn.closed + assert aconn.broken + await aconn.close() + assert aconn.closed + assert aconn.broken + + +@pytest.mark.slow +async def test_identify_closure(aconn_cls, dsn): + async with await aconn_cls.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn2: + cur = await conn.execute("show session_id") + (session_id,) = await cur.fetchone() + + async def closer(): + await asyncio.sleep(0.2) + await conn2.execute("cancel session %s", [session_id]) + + t = create_task(closer()) + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + await conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + assert 0.2 < dt < 2 + finally: + await asyncio.gather(t) diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py new file mode 100644 index 0000000..274a0c0 --- /dev/null +++ b/tests/crdb/test_conninfo.py @@ -0,0 +1,21 @@ +import pytest + +pytestmark = pytest.mark.crdb + + +def test_vendor(conn): + assert conn.info.vendor == "CockroachDB" + + +def test_server_version(conn): + assert conn.info.server_version > 200000 + + +@pytest.mark.crdb("< 22") +def test_backend_pid_pre_22(conn): + assert conn.info.backend_pid == 0 + + +@pytest.mark.crdb(">= 22") +def test_backend_pid(conn): + assert conn.info.backend_pid > 0 diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py new file mode 100644 index 0000000..b7d26aa --- /dev/null +++ b/tests/crdb/test_copy.py @@ -0,0 +1,233 @@ +import pytest +import string +from random import randrange, choice + +from psycopg import sql, errors as e +from psycopg.pq import Format +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import ensure_table, sample_records +from ..test_copy import sample_tabledef as sample_tabledef_pg + +# CRDB int/serial are int8 +sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin with binary") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +def test_copy_big_size_block(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select col2, data from copy_in order by 2").fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, "col1 int primary key, col2 int, data text") + + with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + copy.write_row(row) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def copyopt(format): + return "with binary" if format == Format.BINARY else "" diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py new file mode 100644 index 0000000..d5fbf50 --- /dev/null +++ b/tests/crdb/test_copy_async.py @@ -0,0 +1,235 @@ +import pytest +import string +from random import randrange, choice + +from psycopg.pq import Format +from psycopg import sql, errors as e +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import sample_records +from ..test_copy_async import ensure_table +from .test_copy import sample_tabledef, copyopt + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin with binary") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +@pytest.mark.slow +async def test_copy_big_size_block(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + async with cur.copy( + f"copy copy_in (col2, data) from stdin {copyopt(format)}" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select col2, data from copy_in order by 2") + data = await cur.fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, "col1 int primary key, col2 int, data text") + + async with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py new file mode 100644 index 0000000..991b084 --- /dev/null +++ b/tests/crdb/test_cursor.py @@ -0,0 +1,65 @@ +import json +import threading +from uuid import uuid4 +from queue import Queue +from typing import Any + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.crdb + + +@pytest.fixture +def testfeed(svcconn): + name = f"test_feed_{str(uuid4()).replace('-', '')}" + svcconn.execute("set cluster setting kv.rangefeed.enabled to true") + svcconn.execute(f"create table {name} (id serial primary key, data text)") + yield name + svcconn.execute(f"drop table {name}") + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): + conn.autocommit = True + q: "Queue[Any]" = Queue() + + def worker(): + try: + with conn_cls.connect(dsn, autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + for row in cur.stream(f"experimental changefeed for {testfeed}"): + q.put(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put(None) + except Exception as ex: + q.put(ex) + + t = threading.Thread(target=worker) + t.start() + + cur = conn.cursor() + cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = cur.fetchone() + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = cur.fetchone() + cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert q.get(timeout=1) is None + t.join() diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py new file mode 100644 index 0000000..229295d --- /dev/null +++ b/tests/crdb/test_cursor_async.py @@ -0,0 +1,61 @@ +import json +import asyncio +from typing import Any +from asyncio.queues import Queue + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row +from psycopg._compat import create_task + +from .test_cursor import testfeed + +testfeed # fixture + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): + await aconn.set_autocommit(True) + q: "Queue[Any]" = Queue() + + async def worker(): + try: + async with await aconn_cls.connect(dsn, autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + async for row in cur.stream( + f"experimental changefeed for {testfeed}" + ): + q.put_nowait(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put_nowait(None) + except Exception as ex: + q.put_nowait(ex) + + t = create_task(worker()) + + cur = aconn.cursor() + await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = await cur.fetchone() + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + await cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + await cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = await cur.fetchone() + await cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert await asyncio.wait_for(q.get(), 1.0) is None + await asyncio.gather(t) diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py new file mode 100644 index 0000000..df43f3b --- /dev/null +++ b/tests/crdb/test_no_crdb.py @@ -0,0 +1,34 @@ +from psycopg.pq import TransactionStatus +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb("skip") + + +def test_is_crdb(conn): + assert not CrdbConnection.is_crdb(conn) + assert not CrdbConnection.is_crdb(conn.pgconn) + + +def test_tpc_on_pg_connection(conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py new file mode 100644 index 0000000..2cff0a7 --- /dev/null +++ b/tests/crdb/test_typing.py @@ -0,0 +1,49 @@ +import pytest + +from ..test_typing import _test_reveal + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.crdb.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect()", + "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal_crdb(stmts, type, mypy) + + +def _test_reveal_crdb(stmts, type, mypy): + stmts = f"""\ +import psycopg.crdb +{stmts} +""" + _test_reveal(stmts, type, mypy) diff --git a/tests/dbapi20.py b/tests/dbapi20.py new file mode 100644 index 0000000..c873a4e --- /dev/null +++ b/tests/dbapi20.py @@ -0,0 +1,870 @@ +#!/usr/bin/env python +# flake8: noqa +# fmt: off +''' Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +''' + +__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' +__version__ = '$Revision: 1.12 $'[11:-2] +__author__ = 'Stuart Bishop <stuart@stuartbishop.net>' + +import unittest +import time +import sys +from typing import Any, Dict + + +# Revision 1.12 2009/02/06 03:35:11 kf7xm +# Tested okay with Python 3.0, includes last minute patches from Mark H. +# +# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole +# Include latest changes from main branch +# Updates for py3k +# +# Revision 1.11 2005/01/02 02:41:01 zenzen +# Update author email address +# +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception hierarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propagates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# + +class DatabaseAPI20Test(unittest.TestCase): + ''' Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + from . import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'from .dbapi20 import DatabaseAPI20Test', or you will + confuse the unit tester - just 'from . import dbapi20'. + ''' + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver: Any = None + connect_args = () # List of arguments to pass to connect + connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self,cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self,cursor): + cursor.execute(self.ddl2) + + def setUp(self): + ''' self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + ''' + pass + + def tearDown(self): + ''' self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + ''' + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1,self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + + def _connect(self): + try: + return self.driver.connect( + *self.connect_args,**self.connect_kw_args + ) + except AttributeError: + self.fail("No connect method found in self.driver module") + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel,'2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + self.failUnless(threadsafety in (0,1,2,3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + self.failUnless(paramstyle in ( + 'qmark','numeric','named','format','pyformat' + )) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined hierarchy. + if sys.version[0] == '3': #under Python 3 StardardError no longer exists + self.failUnless(issubclass(self.driver.Warning,Exception)) + self.failUnless(issubclass(self.driver.Error,Exception)) + else: + self.failUnless(issubclass(self.driver.Warning,StandardError)) # type: ignore[name-defined] + self.failUnless(issubclass(self.driver.Error,StandardError)) # type: ignore[name-defined] + + self.failUnless( + issubclass(self.driver.InterfaceError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.DatabaseError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.OperationalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.IntegrityError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.InternalError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.ProgrammingError,self.driver.Error) + ) + self.failUnless( + issubclass(self.driver.NotSupportedError,self.driver.Error) + ) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + con = self._connect() + drv = self.driver + self.failUnless(con.Warning is drv.Warning) + self.failUnless(con.Error is drv.Error) + self.failUnless(con.InterfaceError is drv.InterfaceError) + self.failUnless(con.DatabaseError is drv.DatabaseError) + self.failUnless(con.OperationalError is drv.OperationalError) + self.failUnless(con.IntegrityError is drv.IntegrityError) + self.failUnless(con.InternalError is drv.InternalError) + self.failUnless(con.ProgrammingError is drv.ProgrammingError) + self.failUnless(con.NotSupportedError is drv.NotSupportedError) + con.close() + + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con,'rollback'): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + con.close() + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze),1) + self.assertEqual(len(booze[0]),1) + self.assertEqual(booze[0][0],'Victoria Bitter') + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.description,None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(len(cur.description),1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]),7, + 'cursor.description[x] tuples must have 7 elements' + ) + self.assertEqual(cur.description[0][0].lower(),'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1],self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1] + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual(cur.description,None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. DDL)' + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount should be -1 after executing no-result ' + 'statements' + ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.failUnless(cur.rowcount in (-1,1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount not being reset to -1 after executing ' + 'no-result statements' + ) + finally: + con.close() + + lower_func = 'lower' + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + r = cur.callproc(self.lower_func,('FOO',)) + self.assertEqual(len(r),1) + self.assertEqual(r[0],'FOO') + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error,con.commit) + + # connection.close should raise an Error if called more than once + # Issue discussed on DB-SIG: consensus seem that close() should not + # raised if called on closed objects. Issue reported back to Stuart. + # self.assertRaises(self.driver.Error,con.close) + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def _paraminsert(self,cur): + self.executeDDL1(cur) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.failUnless(cur.rowcount in (-1,1)) + + if self.driver.paramstyle == 'qmark': + cur.execute( + 'insert into %sbooze values (?)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'numeric': + cur.execute( + 'insert into %sbooze values (:1)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'named': + cur.execute( + 'insert into %sbooze values (:beer)' % self.table_prefix, + {'beer':"Cooper's"} + ) + elif self.driver.paramstyle == 'format': + cur.execute( + 'insert into %sbooze values (%%s)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'pyformat': + cur.execute( + 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, + {'beer':"Cooper's"} + ) + else: + self.fail('Invalid paramstyle') + self.failUnless(cur.rowcount in (-1,1)) + + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1],"Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [ ("Cooper's",) , ("Boag's",) ] + margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.failUnless(cur.rowcount in (-1,2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') + self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error,cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannot return rows + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r),1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0],'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if no more rows available' + ) + self.failUnless(cur.rowcount in (-1,1)) + finally: + con.close() + + samples = [ + 'Carlton Cold', + 'Carlton Draft', + 'Mountain Goat', + 'Redback', + 'Victoria Bitter', + 'XXXX' + ] + + def _populate(self): + ''' Return a list of sql commands to setup the DB for the fetch + tests. + ''' + populate = [ + "insert into %sbooze values ('%s')" % (self.table_prefix,s) + for s in self.samples + ] + return populate + + def test_fetchmany(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + #issuing a query + self.assertRaises(self.driver.Error,cur.fetchmany,4) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() + self.assertEqual(len(r),1, + 'cursor.fetchmany retrieved incorrect number of rows, ' + 'default of arraysize is one.' + ) + cur.arraysize=10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r),3, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r),2, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence after ' + 'results are exhausted' + ) + self.failUnless(cur.rowcount in (-1,6)) + + # Same as above, using cursor.arraysize + cur.arraysize=4 + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r),4, + 'cursor.arraysize not being honoured by fetchmany' + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r),2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r),0) + self.failUnless(cur.rowcount in (-1,6)) + + cur.arraysize=6 + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows),6) + self.assertEqual(len(rows),6) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0,6): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved by cursor.fetchmany' + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual(len(rows),0, + 'cursor.fetchmany should return an empty sequence if ' + 'called after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,6)) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence if ' + 'query retrieved no rows' + ) + self.failUnless(cur.rowcount in (-1,0)) + + finally: + con.close() + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + self.assertRaises(self.driver.Error,cur.fetchall) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,len(self.samples))) + self.assertEqual(len(rows),len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) + rows = cur.fetchall() + self.assertEqual( + len(rows),0, + 'cursor.fetchall should return an empty list if called ' + 'after the whole result set has been fetched' + ) + self.failUnless(cur.rowcount in (-1,len(self.samples))) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + rows = cur.fetchall() + self.failUnless(cur.rowcount in (-1,0)) + self.assertEqual(len(rows),0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) + + finally: + con.close() + + def test_mixedfetch(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows56 = cur.fetchall() + self.failUnless(cur.rowcount in (-1,6)) + self.assertEqual(len(rows23),2, + 'fetchmany returned incorrect number of rows' + ) + self.assertEqual(len(rows56),2, + 'fetchall returned incorrect number of rows' + ) + + rows = [rows1[0]] + rows.extend([rows23[0][0],rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0],rows56[1][0]]) + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved or inserted' + ) + finally: + con.close() + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + raise NotImplementedError('Helper not implemented') + #sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + #""" + #cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + raise NotImplementedError('Helper not implemented') + #cur.execute("drop procedure deleteme") + + def test_nextset(self): + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + assert s is None, 'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_arraysize(self): + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + self.failUnless(hasattr(cur,'arraysize'), + 'cursor.arraysize must be defined' + ) + finally: + con.close() + + def test_setinputsizes(self): + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes( (25,) ) + self._paraminsert(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000,0) + self._paraminsert(cur) # Make sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + # Real test for setoutputsize is driver dependent + raise NotImplementedError('Driver needed to override this test') + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r),1) + self.assertEqual(len(r[0]),1) + self.assertEqual(r[0][0],None,'NULL value not returned as None') + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002,12,25) + d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13,45,30) + t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002,12,25,13,45,30,0,0,0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary(b'Something') + b = self.driver.Binary(b'') + + def test_STRING(self): + self.failUnless(hasattr(self.driver,'STRING'), + 'module.STRING must be defined' + ) + + def test_BINARY(self): + self.failUnless(hasattr(self.driver,'BINARY'), + 'module.BINARY must be defined.' + ) + + def test_NUMBER(self): + self.failUnless(hasattr(self.driver,'NUMBER'), + 'module.NUMBER must be defined.' + ) + + def test_DATETIME(self): + self.failUnless(hasattr(self.driver,'DATETIME'), + 'module.DATETIME must be defined.' + ) + + def test_ROWID(self): + self.failUnless(hasattr(self.driver,'ROWID'), + 'module.ROWID must be defined.' + ) +# fmt: on diff --git a/tests/dbapi20_tpc.py b/tests/dbapi20_tpc.py new file mode 100644 index 0000000..7254294 --- /dev/null +++ b/tests/dbapi20_tpc.py @@ -0,0 +1,151 @@ +# flake8: noqa +# fmt: off + +""" Python DB API 2.0 driver Two Phase Commit compliance test suite. + +""" + +import unittest +from typing import Any + + +class TwoPhaseCommitTests(unittest.TestCase): + + driver: Any = None + + def connect(self): + """Make a database connection.""" + raise NotImplementedError + + _last_id = 0 + _global_id_prefix = "dbapi20_tpc:" + + def make_xid(self, con): + id = TwoPhaseCommitTests._last_id + TwoPhaseCommitTests._last_id += 1 + return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier") + + def test_xid(self): + con = self.connect() + try: + try: + xid = con.xid(42, "global", "bqual") + except self.driver.NotSupportedError: + self.fail("Driver does not support transaction IDs.") + + self.assertEquals(xid[0], 42) + self.assertEquals(xid[1], "global") + self.assertEquals(xid[2], "bqual") + + # Try some extremes for the transaction ID: + xid = con.xid(0, "", "") + self.assertEquals(tuple(xid), (0, "", "")) + xid = con.xid(0x7fffffff, "a" * 64, "b" * 64) + self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64)) + finally: + con.close() + + def test_tpc_begin(self): + con = self.connect() + try: + xid = self.make_xid(con) + try: + con.tpc_begin(xid) + except self.driver.NotSupportedError: + self.fail("Driver does not support tpc_begin()") + finally: + con.close() + + def test_tpc_commit_without_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_commit() + finally: + con.close() + + def test_tpc_rollback_without_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_rollback() + finally: + con.close() + + def test_tpc_commit_with_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_prepare() + con.tpc_commit() + finally: + con.close() + + def test_tpc_rollback_with_prepare(self): + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + cursor = con.cursor() + cursor.execute("SELECT 1") + con.tpc_prepare() + con.tpc_rollback() + finally: + con.close() + + def test_tpc_begin_in_transaction_fails(self): + con = self.connect() + try: + xid = self.make_xid(con) + + cursor = con.cursor() + cursor.execute("SELECT 1") + self.assertRaises(self.driver.ProgrammingError, + con.tpc_begin, xid) + finally: + con.close() + + def test_tpc_begin_in_tpc_transaction_fails(self): + con = self.connect() + try: + xid = self.make_xid(con) + + cursor = con.cursor() + cursor.execute("SELECT 1") + self.assertRaises(self.driver.ProgrammingError, + con.tpc_begin, xid) + finally: + con.close() + + def test_commit_in_tpc_fails(self): + # calling commit() within a TPC transaction fails with + # ProgrammingError. + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + + self.assertRaises(self.driver.ProgrammingError, con.commit) + finally: + con.close() + + def test_rollback_in_tpc_fails(self): + # calling rollback() within a TPC transaction fails with + # ProgrammingError. + con = self.connect() + try: + xid = self.make_xid(con) + con.tpc_begin(xid) + + self.assertRaises(self.driver.ProgrammingError, con.rollback) + finally: + con.close() diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py new file mode 100644 index 0000000..88ab504 --- /dev/null +++ b/tests/fix_crdb.py @@ -0,0 +1,131 @@ +from typing import Optional + +import pytest + +from .utils import VersionCheck +from psycopg.crdb import CrdbConnection + + +def pytest_configure(config): + # register libpq marker + config.addinivalue_line( + "markers", + "crdb(version_expr, reason=detail): run/skip the test with matching CockroachDB" + " (e.g. '>= 21.2.10', '< 22.1', 'skip < 22')", + ) + config.addinivalue_line( + "markers", + "crdb_skip(reason): skip the test for known CockroachDB reasons", + ) + + +def check_crdb_version(got, mark): + if mark.name == "crdb": + assert len(mark.args) <= 1 + assert not (set(mark.kwargs) - {"reason"}) + spec = mark.args[0] if mark.args else "only" + reason = mark.kwargs.get("reason") + elif mark.name == "crdb_skip": + assert len(mark.args) == 1 + assert not mark.kwargs + reason = mark.args[0] + assert reason in _crdb_reasons, reason + spec = _crdb_reason_version.get(reason, "skip") + else: + assert False, mark.name + + pred = VersionCheck.parse(spec) + pred.whose = "CockroachDB" + + msg = pred.get_skip_message(got) + if not msg: + return None + + reason = crdb_skip_message(reason) + if reason: + msg = f"{msg}: {reason}" + + return msg + + +# Utility functions which can be imported in the test suite + +is_crdb = CrdbConnection.is_crdb + + +def crdb_skip_message(reason: Optional[str]) -> str: + msg = "" + if reason: + msg = reason + if _crdb_reasons.get(reason): + url = ( + "https://github.com/cockroachdb/cockroach/" + f"issues/{_crdb_reasons[reason]}" + ) + msg = f"{msg} ({url})" + + return msg + + +def skip_crdb(*args, reason=None): + return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason)) + + +def crdb_encoding(*args): + """Mark tests that fail on CockroachDB because of missing encodings""" + return skip_crdb(*args, reason="encoding") + + +def crdb_time_precision(*args): + """Mark tests that fail on CockroachDB because time doesn't support precision""" + return skip_crdb(*args, reason="time precision") + + +def crdb_scs_off(*args): + return skip_crdb(*args, reason="standard_conforming_strings=off") + + +# mapping from reason description to ticket number +_crdb_reasons = { + "2-phase commit": 22329, + "backend pid": 35897, + "batch statements": 44803, + "begin_read_only": 87012, + "binary decimal": 82492, + "cancel": 41335, + "cast adds tz": 51692, + "cidr": 18846, + "composite": 27792, + "copy array": 82792, + "copy canceled": 81559, + "copy": 41608, + "cursor invalid name": 84261, + "cursor with hold": 77101, + "deferrable": 48307, + "do": 17511, + "encoding": 35882, + "geometric types": 21286, + "hstore": 41284, + "infinity date": 41564, + "interval style": 35807, + "json array": 23468, + "large objects": 243, + "negative interval": 81577, + "nested array": 32552, + "no col query": None, + "notify": 41522, + "password_encryption": 42519, + "pg_terminate_backend": 35897, + "range": 41282, + "scroll cursor": 77102, + "server-side cursor": 41412, + "severity_nonlocalized": 81794, + "stored procedure": 1751, +} + +_crdb_reason_version = { + "backend pid": "skip < 22", + "cancel": "skip < 22", + "server-side cursor": "skip < 22.1.3", + "severity_nonlocalized": "skip < 22.1.3", +} diff --git a/tests/fix_db.py b/tests/fix_db.py new file mode 100644 index 0000000..3a37aa1 --- /dev/null +++ b/tests/fix_db.py @@ -0,0 +1,358 @@ +import io +import os +import sys +import pytest +import logging +from contextlib import contextmanager +from typing import Optional + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg._compat import cache +from psycopg.pq._debug import PGconnDebug + +from .utils import check_postgres_version + +# Set by warm_up_database() the first time the dsn fixture is used +pg_version: int +crdb_version: Optional[int] + + +def pytest_addoption(parser): + parser.addoption( + "--test-dsn", + metavar="DSN", + default=os.environ.get("PSYCOPG_TEST_DSN"), + help=( + "Connection string to run database tests requiring a connection" + " [you can also use the PSYCOPG_TEST_DSN env var]." + ), + ) + parser.addoption( + "--pq-trace", + metavar="{TRACEFILE,STDERR}", + default=None, + help="Generate a libpq trace to TRACEFILE or STDERR.", + ) + parser.addoption( + "--pq-debug", + action="store_true", + default=False, + help="Log PGconn access. (Requires PSYCOPG_IMPL=python.)", + ) + + +def pytest_report_header(config): + dsn = config.getoption("--test-dsn") + if dsn is None: + return [] + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + server_version = conn.execute("select version()").fetchall()[0][0] + except Exception as ex: + server_version = f"unknown ({ex})" + + return [ + f"Server version: {server_version}", + ] + + +def pytest_collection_modifyitems(items): + for item in items: + for name in item.fixturenames: + if name in ("pipeline", "apipeline"): + item.add_marker(pytest.mark.pipeline) + break + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="pipeline"): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + + +def pytest_configure(config): + # register pg marker + markers = [ + "pg(version_expr): run the test only with matching server version" + " (e.g. '>= 10', '< 9.6')", + "pipeline: the test runs with connection in pipeline mode", + ] + for marker in markers: + config.addinivalue_line("markers", marker) + + +@pytest.fixture(scope="session") +def session_dsn(request): + """ + Return the dsn used to connect to the `--test-dsn` database (session-wide). + """ + dsn = request.config.getoption("--test-dsn") + if dsn is None: + pytest.skip("skipping test as no --test-dsn") + + warm_up_database(dsn) + return dsn + + +@pytest.fixture +def dsn(session_dsn, request): + """Return the dsn used to connect to the `--test-dsn` database.""" + check_connection_version(request.node) + return session_dsn + + +@pytest.fixture(scope="session") +def tracefile(request): + """Open and yield a file for libpq client/server communication traces if + --pq-tracefile option is set. + """ + tracefile = request.config.getoption("--pq-trace") + if not tracefile: + yield None + return + + if tracefile.lower() == "stderr": + try: + sys.stderr.fileno() + except io.UnsupportedOperation: + raise pytest.UsageError( + "cannot use stderr for --pq-trace (in-memory file?)" + ) from None + + yield sys.stderr + return + + with open(tracefile, "w") as f: + yield f + + +@contextmanager +def maybe_trace(pgconn, tracefile, function): + """Handle libpq client/server communication traces for a single test + function. + """ + if tracefile is None: + yield None + return + + if tracefile != sys.stderr: + title = f" {function.__module__}::{function.__qualname__} ".center(80, "=") + tracefile.write(title + "\n") + tracefile.flush() + + pgconn.trace(tracefile.fileno()) + try: + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + except psycopg.NotSupportedError: + pass + try: + yield None + finally: + pgconn.untrace() + + +@pytest.fixture(autouse=True) +def pgconn_debug(request): + if not request.config.getoption("--pq-debug"): + return + if pq.__impl__ != "python": + raise pytest.UsageError("set PSYCOPG_IMPL=python to use --pq-debug") + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + pq.PGconn = PGconnDebug + + +@pytest.fixture +def pgconn(dsn, request, tracefile): + """Return a PGconn connection open to `--test-dsn`.""" + check_connection_version(request.node) + + conn = pq.PGconn.connect(dsn.encode()) + if conn.status != pq.ConnStatus.OK: + pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}") + + with maybe_trace(conn, tracefile, request.function): + yield conn + + conn.finish() + + +@pytest.fixture +def conn(conn_cls, dsn, request, tracefile): + """Return a `Connection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = conn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +def pipeline(request, conn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + with conn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture +async def aconn(dsn, aconn_cls, request, tracefile): + """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" + check_connection_version(request.node) + + conn = await aconn_cls.connect(dsn) + with maybe_trace(conn.pgconn, tracefile, request.function): + yield conn + await conn.close() + + +@pytest.fixture(params=[True, False], ids=["pipeline=on", "pipeline=off"]) +async def apipeline(request, aconn): + if request.param: + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + async with aconn.pipeline() as p: + yield p + return + else: + yield None + + +@pytest.fixture(scope="session") +def conn_cls(session_dsn): + cls = psycopg.Connection + if crdb_version: + from psycopg.crdb import CrdbConnection + + cls = CrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def aconn_cls(session_dsn): + cls = psycopg.AsyncConnection + if crdb_version: + from psycopg.crdb import AsyncCrdbConnection + + cls = AsyncCrdbConnection + + return cls + + +@pytest.fixture(scope="session") +def svcconn(conn_cls, session_dsn): + """ + Return a session `Connection` connected to the ``--test-dsn`` database. + """ + conn = conn_cls.connect(session_dsn, autocommit=True) + yield conn + conn.close() + + +@pytest.fixture +def commands(conn, monkeypatch): + """The list of commands issued internally by the test connection.""" + yield patch_exec(conn, monkeypatch) + + +@pytest.fixture +def acommands(aconn, monkeypatch): + """The list of commands issued internally by the test async connection.""" + yield patch_exec(aconn, monkeypatch) + + +def patch_exec(conn, monkeypatch): + """Helper to implement the commands fixture both sync and async.""" + _orig_exec_command = conn._exec_command + L = ListPopAll() + + def _exec_command(command, *args, **kwargs): + cmdcopy = command + if isinstance(cmdcopy, bytes): + cmdcopy = cmdcopy.decode(conn.info.encoding) + elif isinstance(cmdcopy, sql.Composable): + cmdcopy = cmdcopy.as_string(conn) + + L.append(cmdcopy) + return _orig_exec_command(command, *args, **kwargs) + + monkeypatch.setattr(conn, "_exec_command", _exec_command) + return L + + +class ListPopAll(list): # type: ignore[type-arg] + """A list, with a popall() method.""" + + def popall(self): + out = self[:] + del self[:] + return out + + +def check_connection_version(node): + try: + pg_version + except NameError: + # First connection creation failed. Let the tests fail. + pytest.fail("server version not available") + + for mark in node.iter_markers(): + if mark.name == "pg": + assert len(mark.args) == 1 + msg = check_postgres_version(pg_version, mark.args[0]) + if msg: + pytest.skip(msg) + + elif mark.name in ("crdb", "crdb_skip"): + from .fix_crdb import check_crdb_version + + msg = check_crdb_version(crdb_version, mark) + if msg: + pytest.skip(msg) + + +@pytest.fixture +def hstore(svcconn): + try: + with svcconn.transaction(): + svcconn.execute("create extension if not exists hstore") + except psycopg.Error as e: + pytest.skip(str(e)) + + +@cache +def warm_up_database(dsn: str) -> None: + """Connect to the database before returning a connection. + + In the CI sometimes, the first test fails with a timeout, probably because + the server hasn't started yet. Absorb the delay before the test. + + In case of error, abort the test run entirely, to avoid failing downstream + hundreds of times. + """ + global pg_version, crdb_version + + try: + with psycopg.connect(dsn, connect_timeout=10) as conn: + conn.execute("select 1") + + pg_version = conn.info.server_version + + crdb_version = None + param = conn.info.parameter_status("crdb_version") + if param: + from psycopg.crdb import CrdbConnectionInfo + + crdb_version = CrdbConnectionInfo.parse_crdb_version(param) + except Exception as exc: + pytest.exit(f"failed to connect to the test database: {exc}") diff --git a/tests/fix_faker.py b/tests/fix_faker.py new file mode 100644 index 0000000..5289d8f --- /dev/null +++ b/tests/fix_faker.py @@ -0,0 +1,868 @@ +import datetime as dt +import importlib +import ipaddress +from math import isnan +from uuid import UUID +from random import choice, random, randrange +from typing import Any, List, Set, Tuple, Union +from decimal import Decimal +from contextlib import contextmanager, asynccontextmanager + +import pytest + +import psycopg +from psycopg import sql +from psycopg.adapt import PyFormat +from psycopg._compat import Deque +from psycopg.types.range import Range +from psycopg.types.json import Json, Jsonb +from psycopg.types.numeric import Int4, Int8 +from psycopg.types.multirange import Multirange + + +@pytest.fixture +def faker(conn): + return Faker(conn) + + +class Faker: + """ + An object to generate random records. + """ + + json_max_level = 3 + json_max_length = 10 + str_max_length = 100 + list_max_length = 20 + tuple_max_length = 15 + + def __init__(self, connection): + self.conn = connection + self.format = PyFormat.BINARY + self.records = [] + + self._schema = None + self._types = None + self._types_names = None + self._makers = {} + self.table_name = sql.Identifier("fake_table") + + @property + def schema(self): + if not self._schema: + self.schema = self.choose_schema() + return self._schema + + @schema.setter + def schema(self, schema): + self._schema = schema + self._types_names = None + + @property + def fields_names(self): + return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))] + + @property + def types(self): + if not self._types: + + def key(cls: type) -> str: + return cls.__name__ + + self._types = sorted(self.get_supported_types(), key=key) + return self._types + + @property + def types_names_sql(self): + if self._types_names: + return self._types_names + + record = self.make_record(nulls=0) + tx = psycopg.adapt.Transformer(self.conn) + types = [ + self._get_type_name(tx, schema, value) + for schema, value in zip(self.schema, record) + ] + self._types_names = types + return types + + @property + def types_names(self): + types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql] + return types + + def _get_type_name(self, tx, schema, value): + # Special case it as it is passed as unknown so is returned as text + if schema == (list, str): + return sql.SQL("text[]") + + registry = self.conn.adapters.types + dumper = tx.get_dumper(value, self.format) + dumper.dump(value) # load the oid if it's dynamic (e.g. array) + info = registry.get(dumper.oid) or registry.get("text") + if dumper.oid == info.array_oid: + return sql.SQL("{}[]").format(sql.Identifier(info.name)) + else: + return sql.Identifier(info.name) + + @property + def drop_stmt(self): + return sql.SQL("drop table if exists {}").format(self.table_name) + + @property + def create_stmt(self): + field_values = [] + for name, type in zip(self.fields_names, self.types_names_sql): + field_values.append(sql.SQL("{} {}").format(name, type)) + + fields = sql.SQL(", ").join(field_values) + return sql.SQL("create table {table} (id serial primary key, {fields})").format( + table=self.table_name, fields=fields + ) + + @property + def insert_stmt(self): + phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))] + return sql.SQL("insert into {} ({}) values ({})").format( + self.table_name, + sql.SQL(", ").join(self.fields_names), + sql.SQL(", ").join(phs), + ) + + @property + def select_stmt(self): + fields = sql.SQL(", ").join(self.fields_names) + return sql.SQL("select {} from {} order by id").format(fields, self.table_name) + + @contextmanager + def find_insert_problem(self, conn): + """Context manager to help finding a problematic value.""" + try: + with conn.transaction(): + yield + except psycopg.DatabaseError: + cur = conn.cursor() + # Repeat insert one field at time, until finding the wrong one + cur.execute(self.drop_stmt) + cur.execute(self.create_stmt) + for i, rec in enumerate(self.records): + for j, val in enumerate(rec): + try: + cur.execute(self._insert_field_stmt(j), (val,)) + except psycopg.DatabaseError as e: + r = repr(val) + if len(r) > 200: + r = f"{r[:200]}... ({len(r)} chars)" + raise Exception( + f"value {r!r} at record {i} column0 {j} failed insert: {e}" + ) from None + + # just in case, but hopefully we should have triggered the problem + raise + + @asynccontextmanager + async def find_insert_problem_async(self, aconn): + try: + async with aconn.transaction(): + yield + except psycopg.DatabaseError: + acur = aconn.cursor() + # Repeat insert one field at time, until finding the wrong one + await acur.execute(self.drop_stmt) + await acur.execute(self.create_stmt) + for i, rec in enumerate(self.records): + for j, val in enumerate(rec): + try: + await acur.execute(self._insert_field_stmt(j), (val,)) + except psycopg.DatabaseError as e: + r = repr(val) + if len(r) > 200: + r = f"{r[:200]}... ({len(r)} chars)" + raise Exception( + f"value {r!r} at record {i} column0 {j} failed insert: {e}" + ) from None + + # just in case, but hopefully we should have triggered the problem + raise + + def _insert_field_stmt(self, i): + ph = sql.Placeholder(format=self.format) + return sql.SQL("insert into {} ({}) values ({})").format( + self.table_name, self.fields_names[i], ph + ) + + def choose_schema(self, ncols=20): + schema: List[Union[Tuple[type, ...], type]] = [] + while len(schema) < ncols: + s = self.make_schema(choice(self.types)) + if s is not None: + schema.append(s) + self.schema = schema + return schema + + def make_records(self, nrecords): + self.records = [self.make_record(nulls=0.05) for i in range(nrecords)] + + def make_record(self, nulls=0): + if not nulls: + return tuple(self.example(spec) for spec in self.schema) + else: + return tuple( + self.make(spec) if random() > nulls else None for spec in self.schema + ) + + def assert_record(self, got, want): + for spec, g, w in zip(self.schema, got, want): + if g is None and w is None: + continue + m = self.get_matcher(spec) + m(spec, g, w) + + def get_supported_types(self) -> Set[type]: + dumpers = self.conn.adapters._dumpers[self.format] + rv = set() + for cls in dumpers.keys(): + if isinstance(cls, str): + cls = deep_import(cls) + if issubclass(cls, Multirange) and self.conn.info.server_version < 140000: + continue + + rv.add(cls) + + # check all the types are handled + for cls in rv: + self.get_maker(cls) + + return rv + + def make_schema(self, cls: type) -> Union[Tuple[type, ...], type, None]: + """Create a schema spec from a Python type. + + A schema specifies what Postgres type to generate when a Python type + maps to more than one (e.g. tuple -> composite, list -> array[], + datetime -> timestamp[tz]). + + A schema for a type is represented by a tuple (type, ...) which the + matching make_*() method can interpret, or just type if the type + doesn't require further specification. + + A `None` means that the type is not supported. + """ + meth = self._get_method("schema", cls) + return meth(cls) if meth else cls + + def get_maker(self, spec): + cls = spec if isinstance(spec, type) else spec[0] + + try: + return self._makers[cls] + except KeyError: + pass + + meth = self._get_method("make", cls) + if meth: + self._makers[cls] = meth + return meth + else: + raise NotImplementedError(f"cannot make fake objects of class {cls}") + + def get_matcher(self, spec): + cls = spec if isinstance(spec, type) else spec[0] + meth = self._get_method("match", cls) + return meth if meth else self.match_any + + def _get_method(self, prefix, cls): + name = cls.__name__ + if cls.__module__ != "builtins": + name = f"{cls.__module__}.{name}" + + parts = name.split(".") + for i in range(len(parts)): + mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}" + meth = getattr(self, mname, None) + if meth: + return meth + + return None + + def make(self, spec): + # spec can be a type or a tuple (type, options) + return self.get_maker(spec)(spec) + + def example(self, spec): + # A good representative of the object - no degenerate case + cls = spec if isinstance(spec, type) else spec[0] + meth = self._get_method("example", cls) + if meth: + return meth(spec) + else: + return self.make(spec) + + def match_any(self, spec, got, want): + assert got == want + + # methods to generate samples of specific types + + def make_Binary(self, spec): + return self.make_bytes(spec) + + def match_Binary(self, spec, got, want): + return want.obj == got + + def make_bool(self, spec): + return choice((True, False)) + + def make_bytearray(self, spec): + return self.make_bytes(spec) + + def make_bytes(self, spec): + length = randrange(self.str_max_length) + return spec(bytes([randrange(256) for i in range(length)])) + + def make_date(self, spec): + day = randrange(dt.date.max.toordinal()) + return dt.date.fromordinal(day + 1) + + def schema_datetime(self, cls): + return self.schema_time(cls) + + def make_datetime(self, spec): + # Add a day because with timezone we might go BC + dtmin = dt.datetime.min + dt.timedelta(days=1) + delta = dt.datetime.max - dtmin + micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000) + rv = dtmin + dt.timedelta(microseconds=micros) + if spec[1]: + rv = rv.replace(tzinfo=self._make_tz(spec)) + return rv + + def match_datetime(self, spec, got, want): + # Comparisons with different timezones is unreliable: certain pairs + # are reported different but their delta is 0 + # https://bugs.python.org/issue45347 + assert not (got - want) + + def make_Decimal(self, spec): + if random() >= 0.99: + return Decimal(choice(self._decimal_special_values())) + + sign = choice("+-") + num = choice(["0.zd", "d", "d.d"]) + while "z" in num: + ndigits = randrange(1, 20) + num = num.replace("z", "0" * ndigits, 1) + while "d" in num: + ndigits = randrange(1, 20) + num = num.replace( + "d", "".join([str(randrange(10)) for i in range(ndigits)]), 1 + ) + expsign = choice(["e+", "e-", ""]) + exp = randrange(20) if expsign else "" + rv = Decimal(f"{sign}{num}{expsign}{exp}") + return rv + + def match_Decimal(self, spec, got, want): + if got is not None and got.is_nan(): + assert want.is_nan() + else: + assert got == want + + def _decimal_special_values(self): + values = ["NaN", "sNaN"] + + if self.conn.info.vendor == "PostgreSQL": + if self.conn.info.server_version >= 140000: + values.extend(["Inf", "-Inf"]) + elif self.conn.info.vendor == "CockroachDB": + if self.conn.info.server_version >= 220100: + values.extend(["Inf", "-Inf"]) + else: + pytest.fail(f"unexpected vendor: {self.conn.info.vendor}") + + return values + + def schema_Enum(self, cls): + # TODO: can't fake those as we would need to create temporary types + return None + + def make_Enum(self, spec): + return None + + def make_float(self, spec, double=True): + if random() <= 0.99: + # These exponents should generate no inf + return float( + f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}" + if double + else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}" + ) + else: + return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan"))) + + def match_float(self, spec, got, want, approx=False, rel=None): + if got is not None and isnan(got): + assert isnan(want) + else: + if approx or self._server_rounds(): + assert got == pytest.approx(want, rel=rel) + else: + assert got == want + + def _server_rounds(self): + """Return True if the connected server perform float rounding""" + if self.conn.info.vendor == "CockroachDB": + return True + else: + # Versions older than 12 make some rounding. e.g. in Postgres 10.4 + # select '-1.409006204063909e+112'::float8 + # -> -1.40900620406391e+112 + return self.conn.info.server_version < 120000 + + def make_Float4(self, spec): + return spec(self.make_float(spec, double=False)) + + def match_Float4(self, spec, got, want): + self.match_float(spec, got, want, approx=True, rel=1e-5) + + def make_Float8(self, spec): + return spec(self.make_float(spec)) + + match_Float8 = match_float + + def make_int(self, spec): + return randrange(-(1 << 90), 1 << 90) + + def make_Int2(self, spec): + return spec(randrange(-(1 << 15), 1 << 15)) + + def make_Int4(self, spec): + return spec(randrange(-(1 << 31), 1 << 31)) + + def make_Int8(self, spec): + return spec(randrange(-(1 << 63), 1 << 63)) + + def make_IntNumeric(self, spec): + return spec(randrange(-(1 << 100), 1 << 100)) + + def make_IPv4Address(self, spec): + return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4))) + + def make_IPv4Interface(self, spec): + prefix = randrange(32) + return ipaddress.IPv4Interface( + (bytes(randrange(256) for _ in range(4)), prefix) + ) + + def make_IPv4Network(self, spec): + return self.make_IPv4Interface(spec).network + + def make_IPv6Address(self, spec): + return ipaddress.IPv6Address(bytes(randrange(256) for _ in range(16))) + + def make_IPv6Interface(self, spec): + prefix = randrange(128) + return ipaddress.IPv6Interface( + (bytes(randrange(256) for _ in range(16)), prefix) + ) + + def make_IPv6Network(self, spec): + return self.make_IPv6Interface(spec).network + + def make_Json(self, spec): + return spec(self._make_json()) + + def match_Json(self, spec, got, want): + if want is not None: + want = want.obj + assert got == want + + def make_Jsonb(self, spec): + return spec(self._make_json()) + + def match_Jsonb(self, spec, got, want): + self.match_Json(spec, got, want) + + def make_JsonFloat(self, spec): + # A float limited to what json accepts + # this exponent should generate no inf + return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}") + + def schema_list(self, cls): + while True: + scls = choice(self.types) + if scls is cls: + continue + if scls is float: + # TODO: float lists are currently adapted as decimal. + # There may be rounding errors or problems with inf. + continue + + # CRDB doesn't support arrays of json + # https://github.com/cockroachdb/cockroach/issues/23468 + if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb): + continue + + schema = self.make_schema(scls) + if schema is not None: + break + + return (cls, schema) + + def make_list(self, spec): + # don't make empty lists because they regularly fail cast + length = randrange(1, self.list_max_length) + spec = spec[1] + while True: + rv = [self.make(spec) for i in range(length)] + + # TODO multirange lists fail binary dump if the last element is + # empty and there is no type annotation. See xfail in + # test_multirange::test_dump_builtin_array + if rv and isinstance(rv[-1], Multirange) and not rv[-1]: + continue + + return rv + + def example_list(self, spec): + return [self.example(spec[1])] + + def match_list(self, spec, got, want): + assert len(got) == len(want) + m = self.get_matcher(spec[1]) + for g, w in zip(got, want): + m(spec[1], g, w) + + def make_memoryview(self, spec): + return self.make_bytes(spec) + + def schema_Multirange(self, cls): + return self.schema_Range(cls) + + def make_Multirange(self, spec, length=None, **kwargs): + if length is None: + length = randrange(0, self.list_max_length) + + def overlap(r1, r2): + l1, u1 = r1.lower, r1.upper + l2, u2 = r2.lower, r2.upper + if l1 is None and l2 is None: + return True + elif l1 is None: + l1 = l2 + elif l2 is None: + l2 = l1 + + if u1 is None and u2 is None: + return True + elif u1 is None: + u1 = u2 + elif u2 is None: + u2 = u1 + + return l1 <= u2 and l2 <= u1 + + out: List[Range[Any]] = [] + for i in range(length): + r = self.make_Range((Range, spec[1]), **kwargs) + if r.isempty: + continue + for r2 in out: + if overlap(r, r2): + insert = False + break + else: + insert = True + if insert: + out.append(r) # alternatively, we could merge + + return spec[0](sorted(out)) + + def example_Multirange(self, spec): + return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0) + + def make_Int4Multirange(self, spec): + return self.make_Multirange((spec, Int4)) + + def make_Int8Multirange(self, spec): + return self.make_Multirange((spec, Int8)) + + def make_NumericMultirange(self, spec): + return self.make_Multirange((spec, Decimal)) + + def make_DateMultirange(self, spec): + return self.make_Multirange((spec, dt.date)) + + def make_TimestampMultirange(self, spec): + return self.make_Multirange((spec, (dt.datetime, False))) + + def make_TimestamptzMultirange(self, spec): + return self.make_Multirange((spec, (dt.datetime, True))) + + def match_Multirange(self, spec, got, want): + assert len(got) == len(want) + for ig, iw in zip(got, want): + self.match_Range(spec, ig, iw) + + def match_Int4Multirange(self, spec, got, want): + return self.match_Multirange((spec, Int4), got, want) + + def match_Int8Multirange(self, spec, got, want): + return self.match_Multirange((spec, Int8), got, want) + + def match_NumericMultirange(self, spec, got, want): + return self.match_Multirange((spec, Decimal), got, want) + + def match_DateMultirange(self, spec, got, want): + return self.match_Multirange((spec, dt.date), got, want) + + def match_TimestampMultirange(self, spec, got, want): + return self.match_Multirange((spec, (dt.datetime, False)), got, want) + + def match_TimestamptzMultirange(self, spec, got, want): + return self.match_Multirange((spec, (dt.datetime, True)), got, want) + + def schema_NoneType(self, cls): + return None + + def make_NoneType(self, spec): + return None + + def make_Oid(self, spec): + return spec(randrange(1 << 32)) + + def schema_Range(self, cls): + subtypes = [ + Decimal, + Int4, + Int8, + dt.date, + (dt.datetime, True), + (dt.datetime, False), + ] + + return (cls, choice(subtypes)) + + def make_Range(self, spec, empty_chance=0.02, no_bound_chance=0.05): + # TODO: drop format check after fixing binary dumping of empty ranges + # (an array starting with an empty range will get the wrong type currently) + if ( + random() < empty_chance + and spec[0] is Range + and self.format == PyFormat.TEXT + ): + return spec[0](empty=True) + + while True: + bounds: List[Union[Any, None]] = [] + while len(bounds) < 2: + if random() < no_bound_chance: + bounds.append(None) + continue + + val = self.make(spec[1]) + # NaN are allowed in a range, but comparison in Python get tricky. + if spec[1] is Decimal and val.is_nan(): + continue + + bounds.append(val) + + if bounds[0] is not None and bounds[1] is not None: + if bounds[0] == bounds[1]: + # It would come out empty + continue + + if bounds[0] > bounds[1]: + bounds.reverse() + + # avoid generating ranges with no type info if dumping in binary + # TODO: lift this limitation after test_copy_in_empty xfail is fixed + if spec[0] is Range and self.format == PyFormat.BINARY: + if bounds[0] is bounds[1] is None: + continue + + break + + r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])")) + return r + + def example_Range(self, spec): + return self.make_Range(spec, empty_chance=0, no_bound_chance=0) + + def make_Int4Range(self, spec): + return self.make_Range((spec, Int4)) + + def make_Int8Range(self, spec): + return self.make_Range((spec, Int8)) + + def make_NumericRange(self, spec): + return self.make_Range((spec, Decimal)) + + def make_DateRange(self, spec): + return self.make_Range((spec, dt.date)) + + def make_TimestampRange(self, spec): + return self.make_Range((spec, (dt.datetime, False))) + + def make_TimestamptzRange(self, spec): + return self.make_Range((spec, (dt.datetime, True))) + + def match_Range(self, spec, got, want): + # normalise the bounds of unbounded ranges + if want.lower is None and want.lower_inc: + want = type(want)(want.lower, want.upper, "(" + want.bounds[1]) + if want.upper is None and want.upper_inc: + want = type(want)(want.lower, want.upper, want.bounds[0] + ")") + + # Normalise discrete ranges + unit: Union[dt.timedelta, int, None] + if spec[1] is dt.date: + unit = dt.timedelta(days=1) + elif type(spec[1]) is type and issubclass(spec[1], int): + unit = 1 + else: + unit = None + + if unit is not None: + if want.lower is not None and not want.lower_inc: + want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1]) + if want.upper_inc: + want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")") + + if spec[1] == (dt.datetime, True) and not want.isempty: + # work around https://bugs.python.org/issue45347 + def fix_dt(x): + return x.astimezone(dt.timezone.utc) if x is not None else None + + def fix_range(r): + return type(r)(fix_dt(r.lower), fix_dt(r.upper), r.bounds) + + want = fix_range(want) + got = fix_range(got) + + assert got == want + + def match_Int4Range(self, spec, got, want): + return self.match_Range((spec, Int4), got, want) + + def match_Int8Range(self, spec, got, want): + return self.match_Range((spec, Int8), got, want) + + def match_NumericRange(self, spec, got, want): + return self.match_Range((spec, Decimal), got, want) + + def match_DateRange(self, spec, got, want): + return self.match_Range((spec, dt.date), got, want) + + def match_TimestampRange(self, spec, got, want): + return self.match_Range((spec, (dt.datetime, False)), got, want) + + def match_TimestamptzRange(self, spec, got, want): + return self.match_Range((spec, (dt.datetime, True)), got, want) + + def make_str(self, spec, length=0): + if not length: + length = randrange(self.str_max_length) + + rv: List[int] = [] + while len(rv) < length: + c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000) + if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF): + rv.append(c) + + return "".join(map(chr, rv)) + + def schema_time(self, cls): + # Choose timezone yes/no + return (cls, choice([True, False])) + + def make_time(self, spec): + val = randrange(24 * 60 * 60 * 1_000_000) + val, ms = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + tz = self._make_tz(spec) if spec[1] else None + return dt.time(h, m, s, ms, tz) + + CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239) + + def make_timedelta(self, spec): + if self.conn.info.vendor == "CockroachDB": + rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX] + else: + rng = [dt.timedelta.min, dt.timedelta.max] + + return choice(rng) * random() + + def schema_tuple(self, cls): + # TODO: this is a complicated matter as it would involve creating + # temporary composite types. + # length = randrange(1, self.tuple_max_length) + # return (cls, self.make_random_schema(ncols=length)) + return None + + def make_tuple(self, spec): + return tuple(self.make(s) for s in spec[1]) + + def match_tuple(self, spec, got, want): + assert len(got) == len(want) == len(spec[1]) + for g, w, s in zip(got, want, spec): + if g is None or w is None: + assert g is w + else: + m = self.get_matcher(s) + m(s, g, w) + + def make_UUID(self, spec): + return UUID(bytes=bytes([randrange(256) for i in range(16)])) + + def _make_json(self, container_chance=0.66): + rec_types = [list, dict] + scal_types = [type(None), int, JsonFloat, bool, str] + if random() < container_chance: + cls = choice(rec_types) + if cls is list: + return [ + self._make_json(container_chance=container_chance / 2.0) + for i in range(randrange(self.json_max_length)) + ] + elif cls is dict: + return { + self.make_str(str, 15): self._make_json( + container_chance=container_chance / 2.0 + ) + for i in range(randrange(self.json_max_length)) + } + else: + assert False, f"unknown rec type: {cls}" + + else: + cls = choice(scal_types) # type: ignore[assignment] + return self.make(cls) + + def _make_tz(self, spec): + minutes = randrange(-12 * 60, 12 * 60 + 1) + return dt.timezone(dt.timedelta(minutes=minutes)) + + +class JsonFloat: + pass + + +def deep_import(name): + parts = Deque(name.split(".")) + seen = [] + if not parts: + raise ValueError("name must be a dot-separated name") + + seen.append(parts.popleft()) + thing = importlib.import_module(seen[-1]) + while parts: + attr = parts.popleft() + seen.append(attr) + + if hasattr(thing, attr): + thing = getattr(thing, attr) + else: + thing = importlib.import_module(".".join(seen)) + + return thing diff --git a/tests/fix_mypy.py b/tests/fix_mypy.py new file mode 100644 index 0000000..b860a32 --- /dev/null +++ b/tests/fix_mypy.py @@ -0,0 +1,54 @@ +import re +import subprocess as sp + +import pytest + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "mypy: the test uses mypy (the marker is set automatically" + " on tests using the fixture)", + ) + + +def pytest_collection_modifyitems(items): + for item in items: + if "mypy" in item.fixturenames: + # add a mypy tag so we can address these tests only + item.add_marker(pytest.mark.mypy) + + # All the tests using mypy are slow + item.add_marker(pytest.mark.slow) + + +@pytest.fixture(scope="session") +def mypy(tmp_path_factory): + cache_dir = tmp_path_factory.mktemp(basename="mypy_cache") + src_dir = tmp_path_factory.mktemp("source") + + class MypyRunner: + def run_on_file(self, filename): + cmdline = f""" + mypy + --strict + --show-error-codes --no-color-output --no-error-summary + --config-file= --cache-dir={cache_dir} + """.split() + cmdline.append(filename) + return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT) + + def run_on_source(self, source): + fn = src_dir / "tmp.py" + with fn.open("w") as f: + f.write(source) + + return self.run_on_file(str(fn)) + + def get_revealed(self, line): + """return the type from an output of reveal_type""" + return re.sub( + r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line + ).replace("*", "") + + return MypyRunner() diff --git a/tests/fix_pq.py b/tests/fix_pq.py new file mode 100644 index 0000000..6811a26 --- /dev/null +++ b/tests/fix_pq.py @@ -0,0 +1,141 @@ +import os +import sys +import ctypes +from typing import Iterator, List, NamedTuple +from tempfile import TemporaryFile + +import pytest + +from .utils import check_libpq_version + +try: + from psycopg import pq +except ImportError: + pq = None # type: ignore + + +def pytest_report_header(config): + try: + from psycopg import pq + except ImportError: + return [] + + return [ + f"libpq wrapper implementation: {pq.__impl__}", + f"libpq used: {pq.version()}", + f"libpq compiled: {pq.__build_version__}", + ] + + +def pytest_configure(config): + # register libpq marker + config.addinivalue_line( + "markers", + "libpq(version_expr): run the test only with matching libpq" + " (e.g. '>= 10', '< 9.6')", + ) + + +def pytest_runtest_setup(item): + for m in item.iter_markers(name="libpq"): + assert len(m.args) == 1 + msg = check_libpq_version(pq.version(), m.args[0]) + if msg: + pytest.skip(msg) + + +@pytest.fixture +def libpq(): + """Return a ctypes wrapper to access the libpq.""" + try: + from psycopg.pq.misc import find_libpq_full_path + + # Not available when testing the binary package + libname = find_libpq_full_path() + assert libname, "libpq libname not found" + return ctypes.pydll.LoadLibrary(libname) + except Exception as e: + if pq.__impl__ == "binary": + pytest.skip(f"can't load libpq for testing: {e}") + else: + raise + + +@pytest.fixture +def setpgenv(monkeypatch): + """Replace the PG* env vars with the vars provided.""" + + def setpgenv_(env): + ks = [k for k in os.environ if k.startswith("PG")] + for k in ks: + monkeypatch.delenv(k) + + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) + + return setpgenv_ + + +@pytest.fixture +def trace(libpq): + pqver = pq.__build_version__ + if pqver < 140000: + pytest.skip(f"trace not available on libpq {pqver}") + if sys.platform != "linux": + pytest.skip(f"trace not available on {sys.platform}") + + yield Tracer() + + +class Tracer: + def trace(self, conn): + pgconn: "pq.abc.PGconn" + + if hasattr(conn, "exec_"): + pgconn = conn + elif hasattr(conn, "cursor"): + pgconn = conn.pgconn + else: + raise Exception() + + return TraceLog(pgconn) + + +class TraceLog: + def __init__(self, pgconn: "pq.abc.PGconn"): + self.pgconn = pgconn + self.tempfile = TemporaryFile(buffering=0) + pgconn.trace(self.tempfile.fileno()) + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS) + + def __del__(self): + if self.pgconn.status == pq.ConnStatus.OK: + self.pgconn.untrace() + self.tempfile.close() + + def __iter__(self) -> "Iterator[TraceEntry]": + self.tempfile.seek(0) + data = self.tempfile.read() + for entry in self._parse_entries(data): + yield entry + + def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]": + for line in data.splitlines(): + direction, length, type, *content = line.split(b"\t") + yield TraceEntry( + direction=direction.decode(), + length=int(length.decode()), + type=type.decode(), + # Note: the items encoding is not very solid: no escaped + # backslash, no escaped quotes. + # At the moment we don't need a proper parser. + content=[content[0]] if content else [], + ) + + +class TraceEntry(NamedTuple): + direction: str + length: int + type: str + content: List[bytes] diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py new file mode 100644 index 0000000..e50f5ec --- /dev/null +++ b/tests/fix_proxy.py @@ -0,0 +1,127 @@ +import os +import time +import socket +import logging +import subprocess as sp +from shutil import which + +import pytest + +import psycopg +from psycopg import conninfo + + +def pytest_collection_modifyitems(items): + for item in items: + # TODO: there is a race condition on macOS and Windows in the CI: + # listen returns before really listening and tests based on 'deaf_port' + # fail 50% of the times. Just add the 'proxy' mark on these tests + # because they are already skipped in the CI. + if "proxy" in item.fixturenames or "deaf_port" in item.fixturenames: + item.add_marker(pytest.mark.proxy) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "proxy: the test uses pproxy (the marker is set automatically" + " on tests using the fixture)", + ) + + +@pytest.fixture +def proxy(dsn): + """Return a proxy to the --test-dsn database""" + p = Proxy(dsn) + yield p + p.stop() + + +@pytest.fixture +def deaf_port(dsn): + """Return a port number with a socket open but not answering""" + with socket.socket(socket.AF_INET) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + s.listen(0) + yield port + + +class Proxy: + """ + Proxy a Postgres service for testing purpose. + + Allow to lose connectivity and restart it using stop/start. + """ + + def __init__(self, server_dsn): + cdict = conninfo.conninfo_to_dict(server_dsn) + + # Get server params + host = cdict.get("host") or os.environ.get("PGHOST") + self.server_host = host if host and not host.startswith("/") else "localhost" + self.server_port = cdict.get("port", "5432") + + # Get client params + self.client_host = "localhost" + self.client_port = self._get_random_port() + + # Make a connection string to the proxy + cdict["host"] = self.client_host + cdict["port"] = self.client_port + cdict["sslmode"] = "disable" # not supported by the proxy + self.client_dsn = conninfo.make_conninfo(**cdict) + + # The running proxy process + self.proc = None + + def start(self): + if self.proc: + logging.info("proxy already started") + return + + logging.info("starting proxy") + pproxy = which("pproxy") + if not pproxy: + raise ValueError("pproxy program not found") + cmdline = [pproxy, "--reuse"] + cmdline.extend(["-l", f"tunnel://:{self.client_port}"]) + cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"]) + + self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL) + logging.info("proxy started") + self._wait_listen() + + # verify that the proxy works + try: + with psycopg.connect(self.client_dsn): + pass + except Exception as e: + pytest.fail(f"failed to create a working proxy: {e}") + + def stop(self): + if not self.proc: + return + + logging.info("stopping proxy") + self.proc.terminate() + self.proc.wait() + logging.info("proxy stopped") + self.proc = None + + @classmethod + def _get_random_port(cls): + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + def _wait_listen(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + for i in range(20): + if 0 == sock.connect_ex((self.client_host, self.client_port)): + break + time.sleep(0.1) + else: + raise ValueError("the proxy didn't start listening in time") + + logging.info("proxy listening") diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py new file mode 100644 index 0000000..80e0c62 --- /dev/null +++ b/tests/fix_psycopg.py @@ -0,0 +1,98 @@ +from copy import deepcopy + +import pytest + + +@pytest.fixture +def global_adapters(): + """Restore the global adapters after a test has changed them.""" + from psycopg import adapters + + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +@pytest.fixture +@pytest.mark.crdb_skip("2-phase commit") +def tpc(svcconn): + tpc = Tpc(svcconn) + tpc.check_tpc() + tpc.clear_test_xacts() + tpc.make_test_table() + yield tpc + tpc.clear_test_xacts() + + +class Tpc: + """Helper object to test two-phase transactions""" + + def __init__(self, conn): + assert conn.autocommit + self.conn = conn + + def check_tpc(self): + from .fix_crdb import is_crdb, crdb_skip_message + + if is_crdb(self.conn): + pytest.skip(crdb_skip_message("2-phase commit")) + + val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0]) + if not val: + pytest.skip("prepared transactions disabled in the database") + + def clear_test_xacts(self): + """Rollback all the prepared transaction in the testing db.""" + from psycopg import sql + + cur = self.conn.execute( + "select gid from pg_prepared_xacts where database = %s", + (self.conn.info.dbname,), + ) + gids = [r[0] for r in cur] + for gid in gids: + self.conn.execute(sql.SQL("rollback prepared {}").format(gid)) + + def make_test_table(self): + self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)") + self.conn.execute("TRUNCATE test_tpc") + + def count_xacts(self): + """Return the number of prepared xacts currently in the test db.""" + cur = self.conn.execute( + """ + select count(*) from pg_prepared_xacts + where database = %s""", + (self.conn.info.dbname,), + ) + return cur.fetchone()[0] + + def count_test_records(self): + """Return the number of records in the test table.""" + cur = self.conn.execute("select count(*) from test_tpc") + return cur.fetchone()[0] + + +@pytest.fixture(scope="module") +def generators(): + """Return the 'generators' module for selected psycopg implementation.""" + from psycopg import pq + + if pq.__impl__ == "c": + from psycopg._cmodule import _psycopg + + return _psycopg + else: + import psycopg.generators + + return psycopg.generators diff --git a/tests/pool/__init__.py b/tests/pool/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/pool/__init__.py diff --git a/tests/pool/fix_pool.py b/tests/pool/fix_pool.py new file mode 100644 index 0000000..12e4f39 --- /dev/null +++ b/tests/pool/fix_pool.py @@ -0,0 +1,12 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "pool: test related to the psycopg_pool package") + + +def pytest_collection_modifyitems(items): + # Add the pool markers to all the tests in the pool package + for item in items: + if "/pool/" in item.nodeid: + item.add_marker(pytest.mark.pool) diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py new file mode 100644 index 0000000..c0e8060 --- /dev/null +++ b/tests/pool/test_null_pool.py @@ -0,0 +1,896 @@ +import logging +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus + +from .test_pool import delay_connection, ensure_waiting + +try: + from psycopg_pool import NullConnectionPool + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +def test_defaults(dsn): + with NullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +def test_min_size_max_size(dsn): + with NullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + NullConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with NullConnectionPool(dsn, connection_class=MyConn) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_no_pool_at_all(dsn): + with NullConnectionPool(dsn, max_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: + assert conn.info.backend_pid not in (pid1, pid2) + + +def test_context(dsn): + with NullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.1) + + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.4) + + +def test_wait_closed(dsn): + with NullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + p.wait(0.2) + + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with NullConnectionPool(dsn, configure=configure) as p: + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 3 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + pids = [] + + def worker(): + with p.connection() as conn: + assert resets == 1 + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + assert resets == 0 + conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + t.join() + p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + conn.execute("select 1") + pids.append(conn.info.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +def test_no_queue_timeout(deaf_port): + with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p: + with pytest.raises(PoolTimeout): + with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with NullConnectionPool(dsn, max_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with NullConnectionPool(dsn, max_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with NullConnectionPool(dsn, max_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + assert not conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + t = Thread(target=worker) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + t = Thread(target=worker, args=(p,)) + t.start() + ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with NullConnectionPool(dsn) as p: + conn = conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with NullConnectionPool(dsn) as p1: + with NullConnectionPool(dsn) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = NullConnectionPool(dsn) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = NullConnectionPool(dsn) + + with p.connection() as conn: + pass + assert conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = NullConnectionPool(dsn, max_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = NullConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +def test_reopen(dsn): + p = NullConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +def test_bad_resize(dsn, min_size, max_size): + with NullConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.1) + + ts = [] + with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + for i in range(5): + ts.append(Thread(target=worker, args=(p,))) + ts[-1].start() + + for t in ts: + t.join() + + assert pids[0] == pids[1] != pids[4], pids + + +def test_check(dsn): + with NullConnectionPool(dsn) as p: + # No-op + p.check() + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with NullConnectionPool(dsn, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + with NullConnectionPool(dsn, max_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with NullConnectionPool(proxy.client_dsn, max_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py new file mode 100644 index 0000000..23a1a52 --- /dev/null +++ b/tests/pool/test_null_pool_async.py @@ -0,0 +1,844 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task +from .test_pool_async import delay_connection, ensure_waiting + +pytestmark = [pytest.mark.asyncio] + +try: + from psycopg_pool import AsyncNullConnectionPool # noqa: F401 + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +async def test_defaults(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +async def test_min_size_max_size(dsn): + async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + AsyncNullConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_no_pool_at_all(dsn): + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid not in (pid1, pid2) + + +async def test_context(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.1) + + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.4) + + +async def test_wait_closed(dsn): + async with AsyncNullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await p.wait(0.2) + + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 3 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + pids = [] + + async def worker(): + async with p.connection() as conn: + assert resets == 1 + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + await p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + await ensure_waiting(p) + + await conn.execute("select 1") + pids.append(conn.info.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +async def test_no_queue_timeout(deaf_port): + async with AsyncNullConnectionPool( + kwargs={"host": "localhost", "port": deaf_port} + ) as p: + with pytest.raises(PoolTimeout): + async with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + await ensure_waiting(p) + + pids.append(conn.info.backend_pid) + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + t = create_task(worker()) + await ensure_waiting(p) + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pids.append(conn.info.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = AsyncNullConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with AsyncNullConnectionPool(dsn) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with AsyncNullConnectionPool(dsn) as p1: + async with AsyncNullConnectionPool(dsn) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = AsyncNullConnectionPool(dsn) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = AsyncNullConnectionPool(dsn) + + async with p.connection() as conn: + pass + assert conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = AsyncNullConnectionPool(dsn, max_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed): + await p.getconn() + + with pytest.raises(PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = AsyncNullConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +async def test_reopen(dsn): + p = AsyncNullConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)]) +async def test_bad_resize(dsn, min_size, max_size): + async with AsyncNullConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_max_lifetime(dsn): + pids: List[int] = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + await asyncio.sleep(0.1) + + async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + ts = [create_task(worker()) for i in range(5)] + await asyncio.gather(*ts) + + assert pids[0] == pids[1] != pids[4], pids + + +async def test_check(dsn): + # no.op + async with AsyncNullConnectionPool(dsn) as p: + await p.check() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with AsyncNullConnectionPool(dsn, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + async with AsyncNullConnectionPool(dsn, max_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py new file mode 100644 index 0000000..30c790b --- /dev/null +++ b/tests/pool/test_pool.py @@ -0,0 +1,1265 @@ +import logging +import weakref +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + + +def test_package_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg_pool import __version__ +assert __version__ +""" + ) + assert not cp.stdout + + +def test_defaults(dsn): + with pool.ConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +def test_min_size_max_size(dsn, min_size, max_size): + with pool.ConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.ConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with pool.ConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p: + with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +def test_its_really_a_pool(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +def test_context(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +def test_connection_not_lost(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + def add_time(self, conn): + times.append(time() - t0) + add_orig(self, conn) + + add_orig = pool.ConnectionPool._add_to_pool + monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + with pool.ConnectionPool(dsn, min_size=5, num_workers=2) as p: + p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.3) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.wait(0.5) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=2) as p: + p.wait(0.3) + p.wait(0.0001) # idempotent + + +def test_wait_closed(dsn): + with pool.ConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + p.wait(0.2) + + with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + p.wait() + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + assert resets == 0 + conn.execute("set timezone to '+2:00'") + + p.wait() + assert resets == 1 + + with p.connection() as conn: + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + + p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(dsn, min_size=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.info.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue(dsn): + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout(dsn): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_queue_timeout_override(dsn): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +def test_broken_reconnect(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + with p.connection() as conn: + pid1 = conn.info.backend_pid + conn.close() + + with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + assert not conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(conn_cls, dsn): + with pool.ConnectionPool(dsn, min_size=1) as p: + conn = conn_cls.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with pool.ConnectionPool(dsn, min_size=1) as p1: + with pool.ConnectionPool(dsn, min_size=1) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +def test_del_no_warning(dsn, recwarn): + p = pool.ConnectionPool(dsn, min_size=2) + with p.connection() as conn: + conn.execute("select 1") + + p.wait() + ref = weakref.ref(p) + del p + assert not ref() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = pool.ConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = pool.ConnectionPool(dsn, min_size=1) + + with p.connection() as conn: + pass + assert not conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = pool.ConnectionPool(dsn, min_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + ensure_waiting(p) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(pool.PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = pool.ConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = pool.ConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.3) + finally: + p.close() + + p = pool.ConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + p.open(wait=True, timeout=0.5) + finally: + p.close() + + +@pytest.mark.slow +@pytest.mark.timing +def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.3) + + with pool.ConnectionPool(dsn, min_size=4, num_workers=1) as p: + p.open(wait=True, timeout=0.5) + + +def test_reopen(dsn): + p = pool.ConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + def worker(n): + t0 = time() + with p.connection() as conn: + conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p: + p.wait(1.0) + results: List[Tuple[int, float]] = [] + + ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool import ShrinkPool + + results: List[Tuple[int, int]] = [] + + def run_hacked(self, pool): + n0 = pool._nconns + orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.1)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + with pool.ConnectionPool(proxy.client_dsn, min_size=1) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + sleep(1.0) + proxy.start() + p.wait() + + with p.connection() as conn: + conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + with pool.ConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + with p.connection() as conn: + conn.execute("select 1") + + t0 = time() + sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + with p.connection() as conn: + conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. + proxy.stop() + + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + assert ev.wait(timeout=2) + + with pytest.raises(pool.PoolTimeout): + with p.connection(timeout=0.5) as conn: + pass + + ev.clear() + assert ev.wait(timeout=2) + + proxy.start() + + with p.connection(timeout=2) as conn: + conn.execute("select 1") + + p.wait(timeout=3.0) + assert len(p._pool) == p.min_size == 4 + + +@pytest.mark.slow +def test_refill_on_check(proxy): + proxy.start() + ev = Event() + + def failed(pool): + ev.set() + + with pool.ConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + # The pool is full + p.wait(timeout=2) + + # Break all the connection + proxy.stop() + + # Checking the pool will empty it + p.check() + assert ev.wait(timeout=2) + assert len(p._pool) == 0 + + # Allow to connect again + proxy.start() + + # Make sure that check has refilled the pool + p.check() + p.wait(timeout=2) + assert len(p._pool) == 4 + + +@pytest.mark.slow +def test_uniform_use(dsn): + with pool.ConnectionPool(dsn, min_size=4) as p: + counts = Counter[int]() + for i in range(8): + with p.connection() as conn: + sleep(0.1) + counts[id(conn)] += 1 + + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +@pytest.mark.slow +@pytest.mark.timing +def test_resize(dsn): + def sampler(): + sleep(0.05) # ensure sampling happens after shrink check + while True: + sleep(0.2) + if p.closed: + break + size.append(len(p._pool)) + + def client(t): + with p.connection() as conn: + conn.execute("select pg_sleep(%s)", [t]) + + size: List[int] = [] + + with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p: + s = Thread(target=sampler) + s.start() + + sleep(0.3) + c = Thread(target=client, args=(0.4,)) + c.start() + + sleep(0.2) + p.resize(4) + assert p.min_size == 4 + assert p.max_size == 4 + + sleep(0.4) + p.resize(2) + assert p.min_size == 2 + assert p.max_size == 2 + + sleep(0.6) + + s.join() + assert size == [2, 1, 3, 4, 3, 2, 2] + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)]) +def test_bad_resize(dsn, min_size, max_size): + with pool.ConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +def test_jitter(): + rnds = [pool.ConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)] + assert 27 <= min(rnds) <= 28 + assert 35 < max(rnds) < 36 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +def test_max_lifetime(dsn): + with pool.ConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: + sleep(0.1) + pids = [] + for i in range(5): + with p.connection() as conn: + pids.append(conn.info.backend_pid) + sleep(0.2) + + assert pids[0] == pids[1] != pids[4], pids + + +@pytest.mark.crdb_skip("backend pid") +def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + with pool.ConnectionPool(dsn, min_size=4) as p: + p.wait(1.0) + with p.connection() as conn: + pid = conn.info.backend_pid + + p.wait(1.0) + pids = set(conn.info.backend_pid for conn in p._pool) + assert pid in pids + conn.close() + + assert len(caplog.records) == 0 + p.check() + assert len(caplog.records) == 1 + p.wait(1.0) + pids2 = set(conn.info.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + +def test_check_idle(dsn): + with pool.ConnectionPool(dsn, min_size=2) as p: + p.wait(1.0) + p.check() + with p.connection() as conn: + assert conn.info.transaction_status == TransactionStatus.IDLE + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with pool.ConnectionPool(dsn, min_size=2, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 2 + assert stats["pool_available"] == 2 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except pool.PoolTimeout: + pass + + with pool.ConnectionPool(dsn, min_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with pool.ConnectionPool(proxy.client_dsn, min_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 3 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 600 <= stats["connections_ms"] < 1200 + + proxy.stop() + p.check() + sleep(0.1) + stats = p.get_stats() + assert stats["connections_num"] > 3 + assert stats["connections_errors"] > 0 + assert stats["connections_lost"] == 3 + + +@pytest.mark.slow +def test_spike(dsn, monkeypatch): + # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/ + # documents/Welcome-To-The-Jungle.md + delay_connection(monkeypatch, 0.15) + + def worker(): + with p.connection(): + sleep(0.002) + + with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p: + p.wait() + + ts = [Thread(target=worker) for i in range(50)] + for t in ts: + t.start() + for t in ts: + t.join() + p.wait() + + assert len(p._pool) < 7 + + +def test_debug_deadlock(dsn): + # https://github.com/psycopg/psycopg/issues/230 + logger = logging.getLogger("psycopg") + handler = logging.StreamHandler() + old_level = logger.level + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + try: + with pool.ConnectionPool(dsn, min_size=4, open=True) as p: + try: + p.wait(timeout=2) + finally: + print(p.get_stats()) + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + + def connect_delay(*args, **kwargs): + t0 = time() + rv = connect_orig(*args, **kwargs) + t1 = time() + sleep(max(0, sec - (t1 - t0))) + return rv + + connect_orig = psycopg.Connection.connect + monkeypatch.setattr(psycopg.Connection, "connect", connect_delay) + + +def ensure_waiting(p, num=1): + """ + Wait until there are at least *num* clients waiting in the queue. + """ + while len(p._waiting) < num: + sleep(0) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py new file mode 100644 index 0000000..286a775 --- /dev/null +++ b/tests/pool/test_pool_async.py @@ -0,0 +1,1198 @@ +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task, Counter + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio] + + +async def test_defaults(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 4 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)]) +async def test_min_size_max_size(dsn, min_size, max_size): + async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p: + assert p.min_size == min_size + assert p.max_size == max_size if max_size is not None else min_size + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + pool.AsyncConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with pool.AsyncConnectionPool( + dsn, kwargs={"autocommit": True}, min_size=1 + ) as p: + async with p.connection() as conn: + assert conn.autocommit + + +@pytest.mark.crdb_skip("backend pid") +async def test_its_really_a_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + async with p.connection() as conn: + assert conn.info.backend_pid in (pid1, pid2) + + +async def test_context(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.crdb_skip("backend pid") +async def test_connection_not_lost(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + with pytest.raises(ZeroDivisionError): + async with p.connection() as conn: + pid = conn.info.backend_pid + 1 / 0 + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + + +@pytest.mark.slow +@pytest.mark.timing +async def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + + async def add_time(self, conn): + times.append(time() - t0) + await add_orig(self, conn) + + add_orig = pool.AsyncConnectionPool._add_to_pool + monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time) + + times: List[float] = [] + t0 = time() + + async with pool.AsyncConnectionPool(dsn, min_size=5, num_workers=2) as p: + await p.wait(1.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.wait(0.5) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=2) as p: + await p.wait(0.3) + await p.wait(0.0001) # idempotent + + +async def test_wait_closed(dsn): + async with pool.AsyncConnectionPool(dsn) as p: + pass + + with pytest.raises(pool.PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await p.wait(0.2) + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=1, num_workers=1 + ) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + await p.wait(timeout=1.0) + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p: + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + + await p.wait() + assert resets == 1 + + async with p.connection() as conn: + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + + await p.wait() + assert resets == 2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(dsn, min_size=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.info.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.info.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue(dsn): + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except pool.TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], pool.TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout(dsn): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_queue_timeout_override(dsn): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.2)") + pid = conn.info.backend_pid + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.crdb_skip("backend pid") +async def test_broken_reconnect(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + async with p.connection() as conn: + pid1 = conn.info.backend_pid + await conn.close() + + async with p.connection() as conn2: + pid2 = conn2.info.backend_pid + + assert pid1 != pid2 + + +@pytest.mark.crdb_skip("backend pid") +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + await conn.execute("create table test_intrans_rollback ()") + assert conn.info.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + cur = await conn2.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid == pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +@pytest.mark.crdb_skip("backend pid") +@pytest.mark.crdb_skip("copy") +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + pid = conn.info.backend_pid + conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout") + assert conn.info.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +@pytest.mark.crdb_skip("backend pid") +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await p.getconn() + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.info.backend_pid + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.info.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.info.backend_pid != pid + assert conn2.info.transaction_status == TransactionStatus.IDLE + + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(aconn_cls, dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p: + conn = await aconn_cls.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1) as p1: + async with pool.AsyncConnectionPool(dsn, min_size=1) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = pool.AsyncConnectionPool(dsn, min_size=1) + + async with p.connection() as conn: + pass + assert not conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except pool.PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = pool.AsyncConnectionPool(dsn, min_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + await ensure_waiting(p) + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(pool.PoolClosed): + await p.getconn() + + with pytest.raises(pool.PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(pool.PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = pool.AsyncConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = pool.AsyncConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.3) + finally: + await p.close() + + p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False) + try: + await p.open(wait=True, timeout=0.5) + finally: + await p.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_open_as_wait(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.3) + + async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p: + await p.open(wait=True, timeout=0.5) + + +async def test_reopen(dsn): + p = pool.AsyncConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.parametrize( + "min_size, want_times", + [ + (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), + (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + ], +) +async def test_grow(dsn, monkeypatch, min_size, want_times): + delay_connection(monkeypatch, 0.1) + + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1 from pg_sleep(0.25)") + t1 = time() + results.append((n, t1 - t0)) + + async with pool.AsyncConnectionPool( + dsn, min_size=min_size, max_size=4, num_workers=3 + ) as p: + await p.wait(1.0) + ts = [] + results: List[Tuple[int, float]] = [] + + ts = [create_task(worker(i)) for i in range(len(want_times))] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +@pytest.mark.timing +async def test_shrink(dsn, monkeypatch): + + from psycopg_pool.pool_async import ShrinkPool + + results: List[Tuple[int, int]] = [] + + async def run_hacked(self, pool): + n0 = pool._nconns + await orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) + + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.1)") + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p: + await p.wait(5.0) + assert p.max_idle == 0.2 + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + await asyncio.sleep(1) + + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +async def test_reconnect(proxy, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + caplog.clear() + proxy.start() + async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=1) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + await asyncio.sleep(1.0) + proxy.start() + await p.wait() + + async with p.connection() as conn: + await conn.execute("select 1") + + assert "BAD" in caplog.messages[0] + times = [rec.created for rec in caplog.records] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, + name="this-one", + min_size=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) as p: + await p.wait(2.0) + proxy.stop() + + with pytest.raises(psycopg.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + t0 = time() + await asyncio.sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + + +@pytest.mark.slow +async def test_reconnect_after_grow_failed(proxy): + # Retry reconnection after a failed connection attempt has put the pool + # in grow mode. See issue #370. + proxy.stop() + + ev = asyncio.Event() + + def failed(pool): + ev.set() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + await asyncio.wait_for(ev.wait(), 2.0) + + with pytest.raises(pool.PoolTimeout): + async with p.connection(timeout=0.5) as conn: + pass + + ev.clear() + await asyncio.wait_for(ev.wait(), 2.0) + + proxy.start() + + async with p.connection(timeout=2) as conn: + await conn.execute("select 1") + + await p.wait(timeout=3.0) + assert len(p._pool) == p.min_size == 4 + + +@pytest.mark.slow +async def test_refill_on_check(proxy): + proxy.start() + ev = asyncio.Event() + + def failed(pool): + ev.set() + + async with pool.AsyncConnectionPool( + proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed + ) as p: + # The pool is full + await p.wait(timeout=2) + + # Break all the connection + proxy.stop() + + # Checking the pool will empty it + await p.check() + await asyncio.wait_for(ev.wait(), 2.0) + assert len(p._pool) == 0 + + # Allow to connect again + proxy.start() + + # Make sure that check has refilled the pool + await p.check() + await p.wait(timeout=2) + assert len(p._pool) == 4 + + +@pytest.mark.slow +async def test_uniform_use(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=4) as p: + counts = Counter[int]() + for i in range(8): + async with p.connection() as conn: + await asyncio.sleep(0.1) + counts[id(conn)] += 1 + + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_resize(dsn): + async def sampler(): + await asyncio.sleep(0.05) # ensure sampling happens after shrink check + while True: + await asyncio.sleep(0.2) + if p.closed: + break + size.append(len(p._pool)) + + async def client(t): + async with p.connection() as conn: + await conn.execute("select pg_sleep(%s)", [t]) + + size: List[int] = [] + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p: + s = create_task(sampler()) + + await asyncio.sleep(0.3) + + c = create_task(client(0.4)) + + await asyncio.sleep(0.2) + await p.resize(4) + assert p.min_size == 4 + assert p.max_size == 4 + + await asyncio.sleep(0.4) + await p.resize(2) + assert p.min_size == 2 + assert p.max_size == 2 + + await asyncio.sleep(0.6) + + await asyncio.gather(s, c) + assert size == [2, 1, 3, 4, 3, 2, 2] + + +@pytest.mark.parametrize("min_size, max_size", [(0, 0), (-1, None), (4, 2)]) +async def test_bad_resize(dsn, min_size, max_size): + async with pool.AsyncConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +async def test_jitter(): + rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)] + assert 27 <= min(rnds) <= 28 + assert 35 < max(rnds) < 36 + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("backend pid") +async def test_max_lifetime(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: + await asyncio.sleep(0.1) + pids = [] + for i in range(5): + async with p.connection() as conn: + pids.append(conn.info.backend_pid) + await asyncio.sleep(0.2) + + assert pids[0] == pids[1] != pids[4], pids + + +@pytest.mark.crdb_skip("backend pid") +async def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + async with pool.AsyncConnectionPool(dsn, min_size=4) as p: + await p.wait(1.0) + async with p.connection() as conn: + pid = conn.info.backend_pid + + await p.wait(1.0) + pids = set(conn.info.backend_pid for conn in p._pool) + assert pid in pids + await conn.close() + + assert len(caplog.records) == 0 + await p.check() + assert len(caplog.records) == 1 + await p.wait(1.0) + pids2 = set(conn.info.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + +async def test_check_idle(dsn): + async with pool.AsyncConnectionPool(dsn, min_size=2) as p: + await p.wait(1.0) + await p.check() + async with p.connection() as conn: + assert conn.info.transaction_status == TransactionStatus.IDLE + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 2 + assert stats["pool_available"] == 2 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 2 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except pool.PoolTimeout: + pass + + async with pool.AsyncConnectionPool(dsn, min_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with pool.AsyncConnectionPool(proxy.client_dsn, min_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 3 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 580 <= stats["connections_ms"] < 1200 + + proxy.stop() + await p.check() + await asyncio.sleep(0.1) + stats = p.get_stats() + assert stats["connections_num"] > 3 + assert stats["connections_errors"] > 0 + assert stats["connections_lost"] == 3 + + +@pytest.mark.slow +async def test_spike(dsn, monkeypatch): + # Inspired to https://github.com/brettwooldridge/HikariCP/blob/dev/ + # documents/Welcome-To-The-Jungle.md + delay_connection(monkeypatch, 0.15) + + async def worker(): + async with p.connection(): + await asyncio.sleep(0.002) + + async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p: + await p.wait() + + ts = [create_task(worker()) for i in range(50)] + await asyncio.gather(*ts) + await p.wait() + + assert len(p._pool) < 7 + + +async def test_debug_deadlock(dsn): + # https://github.com/psycopg/psycopg/issues/230 + logger = logging.getLogger("psycopg") + handler = logging.StreamHandler() + old_level = logger.level + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + try: + async with pool.AsyncConnectionPool(dsn, min_size=4, open=True) as p: + await p.wait(timeout=2) + finally: + logger.removeHandler(handler) + logger.setLevel(old_level) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + + async def connect_delay(*args, **kwargs): + t0 = time() + rv = await connect_orig(*args, **kwargs) + t1 = time() + await asyncio.sleep(max(0, sec - (t1 - t0))) + return rv + + connect_orig = psycopg.AsyncConnection.connect + monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay) + + +async def ensure_waiting(p, num=1): + while len(p._waiting) < num: + await asyncio.sleep(0) diff --git a/tests/pool/test_pool_async_noasyncio.py b/tests/pool/test_pool_async_noasyncio.py new file mode 100644 index 0000000..f6e34e4 --- /dev/null +++ b/tests/pool/test_pool_async_noasyncio.py @@ -0,0 +1,78 @@ +# These tests relate to AsyncConnectionPool, but are not marked asyncio +# because they rely on the pool initialization outside the asyncio loop. + +import asyncio + +import pytest + +from ..utils import gc_collect + +try: + import psycopg_pool as pool +except ImportError: + # Tests should have been skipped if the package is not available + pass + + +@pytest.mark.slow +def test_reconnect_after_max_lifetime(dsn, asyncio_run): + # See issue #219, pool created before the loop. + p = pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2, open=False) + + async def test(): + try: + await p.open() + ns = [] + for i in range(5): + async with p.connection() as conn: + cur = await conn.execute("select 1") + ns.append(await cur.fetchone()) + await asyncio.sleep(0.2) + assert len(ns) == 5 + finally: + await p.close() + + asyncio_run(asyncio.wait_for(test(), timeout=2.0)) + + +@pytest.mark.slow +def test_working_created_before_loop(dsn, asyncio_run): + p = pool.AsyncNullConnectionPool(dsn, open=False) + + async def test(): + try: + await p.open() + ns = [] + for i in range(5): + async with p.connection() as conn: + cur = await conn.execute("select 1") + ns.append(await cur.fetchone()) + await asyncio.sleep(0.2) + assert len(ns) == 5 + finally: + await p.close() + + asyncio_run(asyncio.wait_for(test(), timeout=2.0)) + + +def test_cant_create_open_outside_loop(dsn): + with pytest.raises(RuntimeError): + pool.AsyncConnectionPool(dsn, open=True) + + +@pytest.fixture +def asyncio_run(recwarn): + """Fixture reuturning asyncio.run, but managing resources at exit. + + In certain runs, fd objects are leaked and the error will only be caught + downstream, by some innocent test calling gc_collect(). + """ + recwarn.clear() + try: + yield asyncio.run + finally: + gc_collect() + if recwarn: + warn = recwarn.pop(ResourceWarning) + assert "unclosed event loop" in str(warn.message) + assert not recwarn diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py new file mode 100644 index 0000000..b3d2572 --- /dev/null +++ b/tests/pool/test_sched.py @@ -0,0 +1,154 @@ +import logging +from time import time, sleep +from functools import partial +from threading import Thread + +import pytest + +try: + from psycopg_pool.sched import Scheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.timing] + + +@pytest.mark.slow +def test_sched(): + s = Scheduler() + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +def test_sched_thread(): + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = Scheduler() + t = Thread(target=s.run, daemon=True) + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + def error(): + 1 / 0 + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, None) + s.enter(0.3, partial(worker, 2)) + s.enter(0.2, error) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +def test_empty_queue_timeout(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = Thread(target=s.run) + t.start() + sleep(0.5) + s.enter(0.5, None) + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +def test_first_task_rescheduling(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + s.enter(0.4, lambda: None) + t = Thread(target=s.run) + t.start() + s.enter(0.6, None) # this task doesn't trigger a reschedule + sleep(0.1) + s.enter(0.1, lambda: None) # this triggers a reschedule + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.2), times diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py new file mode 100644 index 0000000..492d620 --- /dev/null +++ b/tests/pool/test_sched_async.py @@ -0,0 +1,159 @@ +import asyncio +import logging +from time import time +from functools import partial + +import pytest + +from psycopg._compat import create_task + +try: + from psycopg_pool.sched import AsyncScheduler +except ImportError: + # Tests should have been skipped if the package is not available + pass + +pytestmark = [pytest.mark.asyncio, pytest.mark.timing] + + +@pytest.mark.slow +async def test_sched(): + s = AsyncScheduler() + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + await s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +async def test_sched_task(): + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.2) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.2) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.2) + + +@pytest.mark.slow +async def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + async def error(): + 1 / 0 + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, None) + await s.enter(0.3, partial(worker, 2)) + await s.enter(0.2, error) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +async def test_empty_queue_timeout(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = create_task(s.run()) + await asyncio.sleep(0.5) + await s.enter(0.5, None) + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +async def test_first_task_rescheduling(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + setattr(s._event, "wait", wait_logging) + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + async def noop(): + pass + + await s.enter(0.4, noop) + t = create_task(s.run()) + await s.enter(0.6, None) # this task doesn't trigger a reschedule + await asyncio.sleep(0.1) + await s.enter(0.1, noop) # this triggers a reschedule + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.2), times diff --git a/tests/pq/__init__.py b/tests/pq/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/pq/__init__.py diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py new file mode 100644 index 0000000..2c3de98 --- /dev/null +++ b/tests/pq/test_async.py @@ -0,0 +1,210 @@ +from select import select + +import pytest + +import psycopg +from psycopg import pq +from psycopg.generators import execute + + +def execute_wait(pgconn): + return psycopg.waiting.wait(execute(pgconn), pgconn.socket) + + +def test_send_query(pgconn): + # This test shows how to process an async query in all its glory + pgconn.nonblocking = 1 + + # Long query to make sure we have to wait on send + pgconn.send_query( + b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;" + % (b"x" * 1_000_000) + ) + + # send loop + waited_on_send = 0 + while True: + f = pgconn.flush() + if f == 0: + break + + waited_on_send += 1 + + rl, wl, xl = select([pgconn.socket], [pgconn.socket], []) + assert not (rl and wl) + if wl: + continue # call flush again() + if rl: + pgconn.consume_input() + continue + + # TODO: this check is not reliable, it fails on travis sometimes + # assert waited_on_send + + # read loop + results = [] + while True: + pgconn.consume_input() + if pgconn.is_busy(): + select([pgconn.socket], [], []) + continue + res = pgconn.get_result() + if res is None: + break + assert res.status == pq.ExecStatus.TUPLES_OK + results.append(res) + + assert len(results) == 2 + assert results[0].nfields == 1 + assert results[0].fname(0) == b"f" + assert results[0].get_value(0, 0) == b"x" + assert results[1].nfields == 1 + assert results[1].fname(0) == b"foo" + assert results[1].get_value(0, 0) == b"1" + + +def test_send_query_compact_test(pgconn): + # Like the above test but use psycopg facilities for compactness + pgconn.send_query( + b"/* %s */ select 'x' as f from pg_sleep(0.01); select 1 as foo;" + % (b"x" * 1_000_000) + ) + results = execute_wait(pgconn) + + assert len(results) == 2 + assert results[0].nfields == 1 + assert results[0].fname(0) == b"f" + assert results[0].get_value(0, 0) == b"x" + assert results[1].nfields == 1 + assert results[1].fname(0) == b"foo" + assert results[1].get_value(0, 0) == b"1" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_query(b"select 1") + + +def test_single_row_mode(pgconn): + pgconn.send_query(b"select generate_series(1,2)") + pgconn.set_single_row_mode() + + results = execute_wait(pgconn) + assert len(results) == 3 + + res = results[0] + assert res.status == pq.ExecStatus.SINGLE_TUPLE + assert res.ntuples == 1 + assert res.get_value(0, 0) == b"1" + + res = results[1] + assert res.status == pq.ExecStatus.SINGLE_TUPLE + assert res.ntuples == 1 + assert res.get_value(0, 0) == b"2" + + res = results[2] + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.ntuples == 0 + + +def test_send_query_params(pgconn): + pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"]) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_query_params(b"select $1", [b"1"]) + + +def test_send_prepare(pgconn): + pgconn.send_prepare(b"prep", b"select $1::int + $2::int") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + (res,) = execute_wait(pgconn) + assert res.get_value(0, 0) == b"8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_prepare(b"prep", b"select $1::int + $2::int") + with pytest.raises(psycopg.OperationalError): + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + + +def test_send_prepare_types(pgconn): + pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23]) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"prep", [b"3", b"5"]) + (res,) = execute_wait(pgconn) + assert res.get_value(0, 0) == b"8" + + +def test_send_prepared_binary_in(pgconn): + val = b"foo\00bar" + pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1]) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_send_prepared_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + pgconn.send_prepare(b"", b"select $1::bytea") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt) + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +def test_send_describe_prepared(pgconn): + pgconn.send_prepare(b"prep", b"select $1::int8 + $2::int8 as fld") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_describe_prepared(b"prep") + (res,) = execute_wait(pgconn) + assert res.nfields == 1 + assert res.ntuples == 0 + assert res.fname(0) == b"fld" + assert res.ftype(0) == 20 + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_describe_prepared(b"prep") + + +@pytest.mark.crdb_skip("server-side cursor") +def test_send_describe_portal(pgconn): + res = pgconn.exec_( + b""" + begin; + declare cur cursor for select * from generate_series(1,10) foo; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + pgconn.send_describe_portal(b"cur") + (res,) = execute_wait(pgconn) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + assert res.nfields == 1 + assert res.fname(0) == b"foo" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.send_describe_portal(b"cur") diff --git a/tests/pq/test_conninfo.py b/tests/pq/test_conninfo.py new file mode 100644 index 0000000..64d8b8f --- /dev/null +++ b/tests/pq/test_conninfo.py @@ -0,0 +1,48 @@ +import pytest + +import psycopg +from psycopg import pq + + +def test_defaults(monkeypatch): + monkeypatch.setenv("PGPORT", "15432") + defs = pq.Conninfo.get_defaults() + assert len(defs) > 20 + port = [d for d in defs if d.keyword == b"port"][0] + assert port.envvar == b"PGPORT" + assert port.compiled == b"5432" + assert port.val == b"15432" + assert port.label == b"Database-Port" + assert port.dispchar == b"" + assert port.dispsize == 6 + + +@pytest.mark.libpq(">= 10") +def test_conninfo_parse(): + infos = pq.Conninfo.parse( + b"postgresql://host1:123,host2:456/somedb" + b"?target_session_attrs=any&application_name=myapp" + ) + info = {i.keyword: i.val for i in infos if i.val is not None} + assert info[b"host"] == b"host1,host2" + assert info[b"port"] == b"123,456" + assert info[b"dbname"] == b"somedb" + assert info[b"application_name"] == b"myapp" + + +@pytest.mark.libpq("< 10") +def test_conninfo_parse_96(): + conninfo = pq.Conninfo.parse( + b"postgresql://other@localhost/otherdb" + b"?connect_timeout=10&application_name=myapp" + ) + info = {i.keyword: i.val for i in conninfo if i.val is not None} + assert info[b"host"] == b"localhost" + assert info[b"dbname"] == b"otherdb" + assert info[b"application_name"] == b"myapp" + + +def test_conninfo_parse_bad(): + with pytest.raises(psycopg.OperationalError) as e: + pq.Conninfo.parse(b"bad_conninfo=") + assert "bad_conninfo" in str(e.value) diff --git a/tests/pq/test_copy.py b/tests/pq/test_copy.py new file mode 100644 index 0000000..383d272 --- /dev/null +++ b/tests/pq/test_copy.py @@ -0,0 +1,174 @@ +import pytest + +import psycopg +from psycopg import pq + +pytestmark = pytest.mark.crdb_skip("copy") + +sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')" + +sample_tabledef = "col1 int primary key, col2 int, data text" + +sample_text = b"""\ +10\t20\thello +40\t\\N\tworld +""" + +sample_binary_value = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f + +0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n") +] + +sample_binary = b"".join(sample_binary_rows) + + +def test_put_data_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_data(b"wat") + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_data(b"wat") + + +def test_put_end_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_end() + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.put_copy_end() + + +def test_copy_out(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_( + b"select min(col1), max(col1), count(*), max(length(data)) from copy_in" + ) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + assert res.get_value(0, 1) == b"199" + assert res.get_value(0, 2) == b"200" + assert res.get_value(0, 3) == b"199" + + +def test_copy_out_err(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\thardly a number\tnope +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end() + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"hardly a number" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def test_copy_out_error_end(pgconn): + ensure_table(pgconn, sample_tabledef) + res = pgconn.exec_(b"copy copy_in from stdin") + assert res.status == pq.ExecStatus.COPY_IN + + for i in range(10): + data = [] + for j in range(20): + data.append( + f"""\ +{i * 20 + j}\t{j}\t{'X' * (i * 20 + j)} +""" + ) + rv = pgconn.put_copy_data("".join(data).encode("ascii")) + assert rv > 0 + + rv = pgconn.put_copy_end(b"nuttengoggenio") + assert rv > 0 + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"nuttengoggenio" in res.error_message + + res = pgconn.exec_(b"select count(*) from copy_in") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"0" + + +def test_get_data_no_copy(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.get_copy_data(0) + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.get_copy_data(0) + + +@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY]) +def test_copy_out_read(pgconn, format): + stmt = f"copy ({sample_values}) to stdout (format {format.name})" + res = pgconn.exec_(stmt.encode("ascii")) + assert res.status == pq.ExecStatus.COPY_OUT + assert res.binary_tuples == format + + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + for row in want: + nbytes, data = pgconn.get_copy_data(0) + assert nbytes == len(data) + assert data == row + + assert pgconn.get_copy_data(0) == (-1, b"") + + res = pgconn.get_result() + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + +def ensure_table(pgconn, tabledef, name="copy_in"): + pgconn.exec_(f"drop table if exists {name}".encode("ascii")) + pgconn.exec_(f"create table {name} ({tabledef})".encode("ascii")) diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py new file mode 100644 index 0000000..ad88d8a --- /dev/null +++ b/tests/pq/test_escaping.py @@ -0,0 +1,188 @@ +import pytest + +import psycopg +from psycopg import pq + +from ..fix_crdb import crdb_scs_off + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b"''"), + (b"hello", b"'hello'"), + (b"foo'bar", b"'foo''bar'"), + (b"foo\\bar", b" E'foo\\\\bar'"), + ], +) +def test_escape_literal(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_literal(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_literal_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b"'": b"''''", b"\\": b" E'\\\\'"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_literal(data) + exp = special.get(data) or b"'%s'" % data + assert rv == exp + + +def test_escape_literal_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b'""'), + (b"hello", b'"hello"'), + (b'foo"bar', b'"foo""bar"'), + (b"foo\\bar", b'"foo\\bar"'), + ], +) +def test_escape_identifier(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_identifier(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_identifier_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b'"': b'""""', b"\\": b'"\\"'} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_identifier(data) + exp = special.get(data) or b'"%s"' % data + assert rv == exp + + +def test_escape_identifier_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + (b"foo\\bar", b"foo\\bar"), + ], +) +def test_escape_string(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_string(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_string_1char(pgconn, scs): + esc = pq.Escaping(pgconn) + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_string(data) + exp = special.get(data) or b"%s" % data + assert rv == exp + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + # This libpq function behaves unpredictably when not passed a conn + (b"foo\\bar", (b"foo\\\\bar", b"foo\\bar")), + ], +) +def test_escape_string_noconn(data, want): + esc = pq.Escaping() + out = esc.escape_string(data) + if isinstance(want, bytes): + assert out == want + else: + assert out in want + + +def test_escape_string_badconn(pgconn): + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_string(b"hi") + + +def test_escape_string_badenc(pgconn): + res = pgconn.exec_(b"set client_encoding to 'UTF8'") + assert res.status == pq.ExecStatus.COMMAND_OK + data = "\u20ac".encode()[:-1] + esc = pq.Escaping(pgconn) + with pytest.raises(psycopg.OperationalError): + esc.escape_string(data) + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_escape_bytea(pgconn, data): + exp = rb"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.escape_bytea(data) + assert rv == exp + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_bytea(data) + + +def test_escape_noconn(pgconn): + data = bytes(range(256)) + esc = pq.Escaping() + escdata = esc.escape_bytea(data) + res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == data + + +def test_escape_1char(pgconn): + esc = pq.Escaping(pgconn) + for c in range(256): + rv = esc.escape_bytea(bytes([c])) + exp = rb"\x%02x" % c + assert rv == exp + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_unescape_bytea(pgconn, data): + enc = rb"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.unescape_bytea(enc) + assert rv == data + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.unescape_bytea(data) diff --git a/tests/pq/test_exec.py b/tests/pq/test_exec.py new file mode 100644 index 0000000..86c30c0 --- /dev/null +++ b/tests/pq/test_exec.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +import pytest + +import psycopg +from psycopg import pq + + +def test_exec_none(pgconn): + with pytest.raises(TypeError): + pgconn.exec_(None) + + +def test_exec(pgconn): + res = pgconn.exec_(b"select 'hel' || 'lo'") + assert res.get_value(0, 0) == b"hello" + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.exec_(b"select 'hello'") + + +def test_exec_params(pgconn): + res = pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.exec_params(b"select $1::int + $2", [b"5", b"3"]) + + +def test_exec_params_empty(pgconn): + res = pgconn.exec_params(b"select 8::int", []) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + + +def test_exec_params_types(pgconn): + res = pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700, 23]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"8" + assert res.ftype(0) == 1700 + assert res.get_value(0, 1) == b"8" + assert res.ftype(1) == 23 + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1, $2", [b"8", b"8"], [1700]) + + +def test_exec_params_nulls(pgconn): + res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"hi" + assert res.get_value(0, 1) == b"" + assert res.get_value(0, 2) is None + + +def test_exec_params_binary_in(pgconn): + val = b"foo\00bar" + res = pgconn.exec_params( + b"select length($1::bytea), length($2::bytea)", + [val, val], + param_formats=[0, 1], + ) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_exec_params_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + res = pgconn.exec_params( + b"select $1::bytea", [val], param_formats=[1], result_format=fmt + ) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +def test_prepare(pgconn): + res = pgconn.prepare(b"prep", b"select $1::int + $2::int") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"prep", [b"3", b"5"]) + assert res.get_value(0, 0) == b"8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.prepare(b"prep", b"select $1::int + $2::int") + with pytest.raises(psycopg.OperationalError): + pgconn.exec_prepared(b"prep", [b"3", b"5"]) + + +def test_prepare_types(pgconn): + res = pgconn.prepare(b"prep", b"select $1 + $2", [23, 23]) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"prep", [b"3", b"5"]) + assert res.get_value(0, 0) == b"8" + + +def test_exec_prepared_binary_in(pgconn): + val = b"foo\00bar" + res = pgconn.prepare(b"", b"select length($1::bytea), length($2::bytea)") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"", [val, val], param_formats=[0, 1]) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == b"3" + assert res.get_value(0, 1) == b"7" + + with pytest.raises(ValueError): + pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1]) + + +@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]) +def test_exec_prepared_binary_out(pgconn, fmt, out): + val = b"foo\00bar" + res = pgconn.prepare(b"", b"select $1::bytea") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == out + + +@pytest.mark.crdb_skip("server-side cursor") +def test_describe_portal(pgconn): + res = pgconn.exec_( + b""" + begin; + declare cur cursor for select * from generate_series(1,10) foo; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.describe_portal(b"cur") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + assert res.nfields == 1 + assert res.fname(0) == b"foo" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.describe_portal(b"cur") diff --git a/tests/pq/test_misc.py b/tests/pq/test_misc.py new file mode 100644 index 0000000..599758f --- /dev/null +++ b/tests/pq/test_misc.py @@ -0,0 +1,83 @@ +import pytest + +import psycopg +from psycopg import pq + + +def test_error_message(pgconn): + res = pgconn.exec_(b"wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + msg = pq.error_message(pgconn) + assert "wat" in msg + assert msg == pq.error_message(res) + primary = res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + assert primary.decode("ascii") in msg + + with pytest.raises(TypeError): + pq.error_message(None) # type: ignore[arg-type] + + res.clear() + assert pq.error_message(res) == "no details available" + pgconn.finish() + assert "NULL" in pq.error_message(pgconn) + + +@pytest.mark.crdb_skip("encoding") +def test_error_message_encoding(pgconn): + res = pgconn.exec_(b"set client_encoding to latin9") + assert res.status == pq.ExecStatus.COMMAND_OK + + res = pgconn.exec_('select 1 from "foo\u20acbar"'.encode("latin9")) + assert res.status == pq.ExecStatus.FATAL_ERROR + + msg = pq.error_message(pgconn) + assert "foo\u20acbar" in msg + + msg = pq.error_message(res) + assert "foo\ufffdbar" in msg + + msg = pq.error_message(res, encoding="latin9") + assert "foo\u20acbar" in msg + + msg = pq.error_message(res, encoding="ascii") + assert "foo\ufffdbar" in msg + + +def test_make_empty_result(pgconn): + pgconn.exec_(b"wat") + res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) + assert res.status == pq.ExecStatus.FATAL_ERROR + assert b"wat" in res.error_message + + pgconn.finish() + res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) + assert res.status == pq.ExecStatus.FATAL_ERROR + assert res.error_message == b"" + + +def test_result_set_attrs(pgconn): + res = pgconn.make_empty_result(pq.ExecStatus.COPY_OUT) + assert res.status == pq.ExecStatus.COPY_OUT + + attrs = [ + pq.PGresAttDesc(b"an_int", 0, 0, 0, 23, 0, 0), + pq.PGresAttDesc(b"a_num", 0, 0, 0, 1700, 0, 0), + pq.PGresAttDesc(b"a_bin_text", 0, 0, 1, 25, 0, 0), + ] + res.set_attributes(attrs) + assert res.nfields == 3 + + assert res.fname(0) == b"an_int" + assert res.fname(1) == b"a_num" + assert res.fname(2) == b"a_bin_text" + + assert res.fformat(0) == 0 + assert res.fformat(1) == 0 + assert res.fformat(2) == 1 + + assert res.ftype(0) == 23 + assert res.ftype(1) == 1700 + assert res.ftype(2) == 25 + + with pytest.raises(psycopg.OperationalError): + res.set_attributes(attrs) diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py new file mode 100644 index 0000000..0566151 --- /dev/null +++ b/tests/pq/test_pgconn.py @@ -0,0 +1,585 @@ +import os +import sys +import ctypes +import logging +import weakref +from select import select + +import pytest + +import psycopg +from psycopg import pq +import psycopg.generators + +from ..utils import gc_collect + + +def test_connectdb(dsn): + conn = pq.PGconn.connect(dsn.encode()) + assert conn.status == pq.ConnStatus.OK, conn.error_message + + +def test_connectdb_error(): + conn = pq.PGconn.connect(b"dbname=psycopg_test_not_for_real") + assert conn.status == pq.ConnStatus.BAD + + +@pytest.mark.parametrize("baddsn", [None, 42]) +def test_connectdb_badtype(baddsn): + with pytest.raises(TypeError): + pq.PGconn.connect(baddsn) + + +def test_connect_async(dsn): + conn = pq.PGconn.connect_start(dsn.encode()) + conn.nonblocking = 1 + while True: + assert conn.status != pq.ConnStatus.BAD + rv = conn.connect_poll() + if rv == pq.PollingStatus.OK: + break + elif rv == pq.PollingStatus.READING: + select([conn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [conn.socket], []) + else: + assert False, rv + + assert conn.status == pq.ConnStatus.OK + + conn.finish() + with pytest.raises(psycopg.OperationalError): + conn.connect_poll() + + +@pytest.mark.crdb("skip", reason="connects to any db name") +def test_connect_async_bad(dsn): + parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val} + parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real" + dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items()) + conn = pq.PGconn.connect_start(dsn) + while True: + assert conn.status != pq.ConnStatus.BAD, conn.error_message + rv = conn.connect_poll() + if rv == pq.PollingStatus.FAILED: + break + elif rv == pq.PollingStatus.READING: + select([conn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [conn.socket], []) + else: + assert False, rv + + assert conn.status == pq.ConnStatus.BAD + + +def test_finish(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + pgconn.finish() + assert pgconn.status == pq.ConnStatus.BAD + + +@pytest.mark.slow +def test_weakref(dsn): + conn = pq.PGconn.connect(dsn.encode()) + w = weakref.ref(conn) + conn.finish() + del conn + gc_collect() + assert w() is None + + +@pytest.mark.skipif( + sys.platform == "win32" + and os.environ.get("CI") == "true" + and pq.__impl__ != "python", + reason="can't figure out how to make ctypes run, don't care", +) +def test_pgconn_ptr(pgconn, libpq): + assert isinstance(pgconn.pgconn_ptr, int) + + f = libpq.PQserverVersion + f.argtypes = [ctypes.c_void_p] + f.restype = ctypes.c_int + ver = f(pgconn.pgconn_ptr) + assert ver == pgconn.server_version + + pgconn.finish() + assert pgconn.pgconn_ptr is None + + +def test_info(dsn, pgconn): + info = pgconn.info + assert len(info) > 20 + dbname = [d for d in info if d.keyword == b"dbname"][0] + assert dbname.envvar == b"PGDATABASE" + assert dbname.label == b"Database-Name" + assert dbname.dispchar == b"" + assert dbname.dispsize == 20 + + parsed = pq.Conninfo.parse(dsn.encode()) + # take the name and the user either from params or from env vars + name = [ + o.val or os.environ.get(o.envvar.decode(), "").encode() + for o in parsed + if o.keyword == b"dbname" and o.envvar + ][0] + user = [ + o.val or os.environ.get(o.envvar.decode(), "").encode() + for o in parsed + if o.keyword == b"user" and o.envvar + ][0] + assert dbname.val == (name or user) + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.info + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_reset(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD + pgconn.reset() + assert pgconn.status == pq.ConnStatus.OK + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.reset() + + assert pgconn.status == pq.ConnStatus.BAD + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_reset_async(pgconn): + assert pgconn.status == pq.ConnStatus.OK + pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())") + assert pgconn.status == pq.ConnStatus.BAD + pgconn.reset_start() + while True: + rv = pgconn.reset_poll() + if rv == pq.PollingStatus.READING: + select([pgconn.socket], [], []) + elif rv == pq.PollingStatus.WRITING: + select([], [pgconn.socket], []) + else: + break + + assert rv == pq.PollingStatus.OK + assert pgconn.status == pq.ConnStatus.OK + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.reset_start() + + with pytest.raises(psycopg.OperationalError): + pgconn.reset_poll() + + +def test_ping(dsn): + rv = pq.PGconn.ping(dsn.encode()) + assert rv == pq.Ping.OK + + rv = pq.PGconn.ping(b"port=9999") + assert rv == pq.Ping.NO_RESPONSE + + +def test_db(pgconn): + name = [o.val for o in pgconn.info if o.keyword == b"dbname"][0] + assert pgconn.db == name + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.db + + +def test_user(pgconn): + user = [o.val for o in pgconn.info if o.keyword == b"user"][0] + assert pgconn.user == user + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.user + + +def test_password(pgconn): + # not in info + assert isinstance(pgconn.password, bytes) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.password + + +def test_host(pgconn): + # might be not in info + assert isinstance(pgconn.host, bytes) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.host + + +@pytest.mark.libpq(">= 12") +def test_hostaddr(pgconn): + # not in info + assert isinstance(pgconn.hostaddr, bytes), pgconn.hostaddr + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.hostaddr + + +@pytest.mark.libpq("< 12") +def test_hostaddr_missing(pgconn): + with pytest.raises(psycopg.NotSupportedError): + pgconn.hostaddr + + +def test_port(pgconn): + port = [o.val for o in pgconn.info if o.keyword == b"port"][0] + assert pgconn.port == port + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.port + + +@pytest.mark.libpq("< 14") +def test_tty(pgconn): + tty = [o.val for o in pgconn.info if o.keyword == b"tty"][0] + assert pgconn.tty == tty + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.tty + + +@pytest.mark.libpq(">= 14") +def test_tty_noop(pgconn): + assert not any(o.val for o in pgconn.info if o.keyword == b"tty") + assert pgconn.tty == b"" + + +def test_transaction_status(pgconn): + assert pgconn.transaction_status == pq.TransactionStatus.IDLE + pgconn.exec_(b"begin") + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS + pgconn.send_query(b"select 1") + assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE + psycopg.waiting.wait(psycopg.generators.execute(pgconn), pgconn.socket) + assert pgconn.transaction_status == pq.TransactionStatus.INTRANS + pgconn.finish() + assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN + + +def test_parameter_status(dsn, monkeypatch): + monkeypatch.setenv("PGAPPNAME", "psycopg tests") + pgconn = pq.PGconn.connect(dsn.encode()) + assert pgconn.parameter_status(b"application_name") == b"psycopg tests" + assert pgconn.parameter_status(b"wat") is None + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.parameter_status(b"application_name") + + +@pytest.mark.crdb_skip("encoding") +def test_encoding(pgconn): + res = pgconn.exec_(b"set client_encoding to latin1") + assert res.status == pq.ExecStatus.COMMAND_OK + assert pgconn.parameter_status(b"client_encoding") == b"LATIN1" + + res = pgconn.exec_(b"set client_encoding to 'utf-8'") + assert res.status == pq.ExecStatus.COMMAND_OK + assert pgconn.parameter_status(b"client_encoding") == b"UTF8" + + res = pgconn.exec_(b"set client_encoding to wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + assert pgconn.parameter_status(b"client_encoding") == b"UTF8" + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.parameter_status(b"client_encoding") + + +def test_protocol_version(pgconn): + assert pgconn.protocol_version == 3 + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.protocol_version + + +def test_server_version(pgconn): + assert pgconn.server_version >= 90400 + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.server_version + + +def test_socket(pgconn): + socket = pgconn.socket + assert socket > 0 + pgconn.exec_(f"select pg_terminate_backend({pgconn.backend_pid})".encode()) + # TODO: on my box it raises OperationalError as it should. Not on Travis, + # so let's see if at least an ok value comes out of it. + try: + assert pgconn.socket == socket + except psycopg.OperationalError: + pass + + +def test_error_message(pgconn): + assert pgconn.error_message == b"" + res = pgconn.exec_(b"wat") + assert res.status == pq.ExecStatus.FATAL_ERROR + msg = pgconn.error_message + assert b"wat" in msg + pgconn.finish() + assert b"NULL" in pgconn.error_message # TODO: i10n? + + +def test_backend_pid(pgconn): + assert isinstance(pgconn.backend_pid, int) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.backend_pid + + +def test_needs_password(pgconn): + # assume connection worked so an eventually needed password wasn't missing + assert pgconn.needs_password is False + pgconn.finish() + pgconn.needs_password + + +def test_used_password(pgconn, dsn, monkeypatch): + assert isinstance(pgconn.used_password, bool) + + # Assume that if a password was passed then it was needed. + # Note that the server may still need a password passed via pgpass + # so it may be that has_password is false but still a password was + # requested by the server and passed by libpq. + info = pq.Conninfo.parse(dsn.encode()) + has_password = ( + "PGPASSWORD" in os.environ + or [i for i in info if i.keyword == b"password"][0].val is not None + ) + if has_password: + assert pgconn.used_password + + pgconn.finish() + pgconn.used_password + + +def test_ssl_in_use(pgconn): + assert isinstance(pgconn.ssl_in_use, bool) + + # If connecting via socket then ssl is not in use + if pgconn.host.startswith(b"/"): + assert not pgconn.ssl_in_use + else: + sslmode = [i.val for i in pgconn.info if i.keyword == b"sslmode"][0] + if sslmode not in (b"disable", b"allow", b"prefer"): + assert pgconn.ssl_in_use + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + pgconn.ssl_in_use + + +def test_set_single_row_mode(pgconn): + with pytest.raises(psycopg.OperationalError): + pgconn.set_single_row_mode() + + pgconn.send_query(b"select 1") + pgconn.set_single_row_mode() + + +def test_cancel(pgconn): + cancel = pgconn.get_cancel() + cancel.cancel() + cancel.cancel() + pgconn.finish() + cancel.cancel() + with pytest.raises(psycopg.OperationalError): + pgconn.get_cancel() + + +def test_cancel_free(pgconn): + cancel = pgconn.get_cancel() + cancel.free() + with pytest.raises(psycopg.OperationalError): + cancel.cancel() + cancel.free() + + +@pytest.mark.crdb_skip("notify") +def test_notify(pgconn): + assert pgconn.notifies() is None + + pgconn.exec_(b"listen foo") + pgconn.exec_(b"listen bar") + pgconn.exec_(b"notify foo, '1'") + pgconn.exec_(b"notify bar, '2'") + pgconn.exec_(b"notify foo, '3'") + + n = pgconn.notifies() + assert n.relname == b"foo" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"1" + + n = pgconn.notifies() + assert n.relname == b"bar" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"2" + + n = pgconn.notifies() + assert n.relname == b"foo" + assert n.be_pid == pgconn.backend_pid + assert n.extra == b"3" + + assert pgconn.notifies() is None + + +@pytest.mark.crdb_skip("do") +def test_notice_nohandler(pgconn): + pgconn.exec_(b"set client_min_messages to notice") + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + assert res.status == pq.ExecStatus.COMMAND_OK + + +@pytest.mark.crdb_skip("do") +def test_notice(pgconn): + msgs = [] + + def callback(res): + assert res.status == pq.ExecStatus.NONFATAL_ERROR + msgs.append(res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)) + + pgconn.exec_(b"set client_min_messages to notice") + pgconn.notice_handler = callback + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + + assert res.status == pq.ExecStatus.COMMAND_OK + assert msgs and msgs[0] == b"hello notice" + + +@pytest.mark.crdb_skip("do") +def test_notice_error(pgconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + def callback(res): + raise Exception("hello error") + + pgconn.exec_(b"set client_min_messages to notice") + pgconn.notice_handler = callback + res = pgconn.exec_( + b"do $$begin raise notice 'hello notice'; end$$ language plpgsql" + ) + + assert res.status == pq.ExecStatus.COMMAND_OK + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello error" in rec.message + + +@pytest.mark.libpq("< 14") +@pytest.mark.skipif("sys.platform != 'linux'") +def test_trace_pre14(pgconn, tmp_path): + tracef = tmp_path / "trace" + with tracef.open("w") as f: + pgconn.trace(f.fileno()) + with pytest.raises(psycopg.NotSupportedError): + pgconn.set_trace_flags(0) + pgconn.exec_(b"select 1") + pgconn.untrace() + pgconn.exec_(b"select 2") + traces = tracef.read_text() + assert "select 1" in traces + assert "select 2" not in traces + + +@pytest.mark.libpq(">= 14") +@pytest.mark.skipif("sys.platform != 'linux'") +def test_trace(pgconn, tmp_path): + tracef = tmp_path / "trace" + with tracef.open("w") as f: + pgconn.trace(f.fileno()) + pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + pgconn.exec_(b"select 1::int4 as foo") + pgconn.untrace() + pgconn.exec_(b"select 2::int4 as foo") + traces = [line.split("\t") for line in tracef.read_text().splitlines()] + assert traces == [ + ["F", "26", "Query", ' "select 1::int4 as foo"'], + ["B", "28", "RowDescription", ' 1 "foo" NNNN 0 NNNN 4 -1 0'], + ["B", "11", "DataRow", " 1 1 '1'"], + ["B", "13", "CommandComplete", ' "SELECT 1"'], + ["B", "5", "ReadyForQuery", " I"], + ] + + +@pytest.mark.skipif("sys.platform == 'linux'") +def test_trace_nonlinux(pgconn): + with pytest.raises(psycopg.NotSupportedError): + pgconn.trace(1) + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password(pgconn): + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5") + assert enc == b"md594839d658c28a357126f105b9cb14cfc" + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_scram(pgconn): + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256") + assert enc.startswith(b"SCRAM-SHA-256$") + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_badalgo(pgconn): + with pytest.raises(psycopg.OperationalError): + assert pgconn.encrypt_password(b"psycopg2", b"ashesh", b"wat") + + +@pytest.mark.libpq(">= 10") +@pytest.mark.crdb_skip("password_encryption") +def test_encrypt_password_query(pgconn): + res = pgconn.exec_(b"set password_encryption to 'md5'") + assert res.status == pq.ExecStatus.COMMAND_OK, pgconn.error_message.decode() + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh") + assert enc == b"md594839d658c28a357126f105b9cb14cfc" + + res = pgconn.exec_(b"set password_encryption to 'scram-sha-256'") + assert res.status == pq.ExecStatus.COMMAND_OK + enc = pgconn.encrypt_password(b"psycopg2", b"ashesh") + assert enc.startswith(b"SCRAM-SHA-256$") + + +@pytest.mark.libpq(">= 10") +def test_encrypt_password_closed(pgconn): + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + assert pgconn.encrypt_password(b"psycopg2", b"ashesh") + + +@pytest.mark.libpq("< 10") +def test_encrypt_password_not_supported(pgconn): + # it might even be supported, but not worth the lifetime + with pytest.raises(psycopg.NotSupportedError): + pgconn.encrypt_password(b"psycopg2", b"ashesh", b"md5") + + with pytest.raises(psycopg.NotSupportedError): + pgconn.encrypt_password(b"psycopg2", b"ashesh", b"scram-sha-256") + + +def test_str(pgconn, dsn): + assert "[IDLE]" in str(pgconn) + pgconn.finish() + assert "[BAD]" in str(pgconn) + + pgconn2 = pq.PGconn.connect_start(dsn.encode()) + assert "[" in str(pgconn2) + assert "[IDLE]" not in str(pgconn2) diff --git a/tests/pq/test_pgresult.py b/tests/pq/test_pgresult.py new file mode 100644 index 0000000..3ad818d --- /dev/null +++ b/tests/pq/test_pgresult.py @@ -0,0 +1,207 @@ +import ctypes +import pytest + +from psycopg import pq + + +@pytest.mark.parametrize( + "command, status", + [ + (b"", "EMPTY_QUERY"), + (b"select 1", "TUPLES_OK"), + (b"set timezone to utc", "COMMAND_OK"), + (b"wat", "FATAL_ERROR"), + ], +) +def test_status(pgconn, command, status): + res = pgconn.exec_(command) + assert res.status == getattr(pq.ExecStatus, status) + assert status in repr(res) + + +def test_clear(pgconn): + res = pgconn.exec_(b"select 1") + assert res.status == pq.ExecStatus.TUPLES_OK + res.clear() + assert res.status == pq.ExecStatus.FATAL_ERROR + res.clear() + assert res.status == pq.ExecStatus.FATAL_ERROR + + +def test_pgresult_ptr(pgconn, libpq): + res = pgconn.exec_(b"select 1") + assert isinstance(res.pgresult_ptr, int) + + f = libpq.PQcmdStatus + f.argtypes = [ctypes.c_void_p] + f.restype = ctypes.c_char_p + assert f(res.pgresult_ptr) == b"SELECT 1" + + res.clear() + assert res.pgresult_ptr is None + + +def test_error_message(pgconn): + res = pgconn.exec_(b"select 1") + assert res.error_message == b"" + res = pgconn.exec_(b"select wat") + assert b"wat" in res.error_message + res.clear() + assert res.error_message == b"" + + +def test_error_field(pgconn): + res = pgconn.exec_(b"select wat") + # https://github.com/cockroachdb/cockroach/issues/81794 + assert ( + res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) + or res.error_field(pq.DiagnosticField.SEVERITY) + ) == b"ERROR" + assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703" + assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + res.clear() + assert res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) is None + + +@pytest.mark.parametrize("n", range(4)) +def test_ntuples(pgconn, n): + res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")]) + assert res.ntuples == n + res.clear() + assert res.ntuples == 0 + + +def test_nfields(pgconn): + res = pgconn.exec_(b"select wat") + assert res.nfields == 0 + res = pgconn.exec_(b"select 1, 2, 3") + assert res.nfields == 3 + res.clear() + assert res.nfields == 0 + + +def test_fname(pgconn): + res = pgconn.exec_(b'select 1 as foo, 2 as "BAR"') + assert res.fname(0) == b"foo" + assert res.fname(1) == b"BAR" + assert res.fname(2) is None + assert res.fname(-1) is None + res.clear() + assert res.fname(0) is None + + +@pytest.mark.crdb("skip", reason="ftable") +def test_ftable_and_col(pgconn): + res = pgconn.exec_( + b""" + drop table if exists t1, t2; + create table t1 as select 1 as f1; + create table t2 as select 2 as f2, 3 as f3; + """ + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.exec_( + b"select f1, f3, 't1'::regclass::oid, 't2'::regclass::oid from t1, t2" + ) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + + assert res.ftable(0) == int(res.get_value(0, 2).decode("ascii")) + assert res.ftable(1) == int(res.get_value(0, 3).decode("ascii")) + assert res.ftablecol(0) == 1 + assert res.ftablecol(1) == 2 + res.clear() + assert res.ftable(0) == 0 + assert res.ftablecol(0) == 0 + + +@pytest.mark.parametrize("fmt", (0, 1)) +def test_fformat(pgconn, fmt): + res = pgconn.exec_params(b"select 1", [], result_format=fmt) + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fformat(0) == fmt + assert res.binary_tuples == fmt + res.clear() + assert res.fformat(0) == 0 + assert res.binary_tuples == 0 + + +def test_ftype(pgconn): + res = pgconn.exec_(b"select 1::int4, 1::numeric, 1::text") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.ftype(0) == 23 + assert res.ftype(1) == 1700 + assert res.ftype(2) == 25 + res.clear() + assert res.ftype(0) == 0 + + +def test_fmod(pgconn): + res = pgconn.exec_(b"select 1::int, 1::numeric(10), 1::numeric(10,2)") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fmod(0) == -1 + assert res.fmod(1) == 0xA0004 + assert res.fmod(2) == 0xA0006 + res.clear() + assert res.fmod(0) == 0 + + +def test_fsize(pgconn): + res = pgconn.exec_(b"select 1::int4, 1::bigint, 1::text") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.fsize(0) == 4 + assert res.fsize(1) == 8 + assert res.fsize(2) == -1 + res.clear() + assert res.fsize(0) == 0 + + +def test_get_value(pgconn): + res = pgconn.exec_(b"select 'a', '', NULL") + assert res.status == pq.ExecStatus.TUPLES_OK, res.error_message + assert res.get_value(0, 0) == b"a" + assert res.get_value(0, 1) == b"" + assert res.get_value(0, 2) is None + res.clear() + assert res.get_value(0, 0) is None + + +def test_nparams_types(pgconn): + res = pgconn.prepare(b"", b"select $1::int4, $2::text") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + res = pgconn.describe_prepared(b"") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + assert res.nparams == 2 + assert res.param_type(0) == 23 + assert res.param_type(1) == 25 + + res.clear() + assert res.nparams == 0 + assert res.param_type(0) == 0 + + +def test_command_status(pgconn): + res = pgconn.exec_(b"select 1") + assert res.command_status == b"SELECT 1" + res = pgconn.exec_(b"set timezone to utc") + assert res.command_status == b"SET" + res.clear() + assert res.command_status is None + + +def test_command_tuples(pgconn): + res = pgconn.exec_(b"set timezone to utf8") + assert res.command_tuples is None + res = pgconn.exec_(b"select * from generate_series(1, 10)") + assert res.command_tuples == 10 + res.clear() + assert res.command_tuples is None + + +def test_oid_value(pgconn): + res = pgconn.exec_(b"select 1") + assert res.oid_value == 0 + res.clear() + assert res.oid_value == 0 diff --git a/tests/pq/test_pipeline.py b/tests/pq/test_pipeline.py new file mode 100644 index 0000000..00cd54a --- /dev/null +++ b/tests/pq/test_pipeline.py @@ -0,0 +1,161 @@ +import pytest + +import psycopg +from psycopg import pq + + +@pytest.mark.libpq("< 14") +def test_old_libpq(pgconn): + assert pgconn.pipeline_status == 0 + with pytest.raises(psycopg.NotSupportedError): + pgconn.enter_pipeline_mode() + with pytest.raises(psycopg.NotSupportedError): + pgconn.exit_pipeline_mode() + with pytest.raises(psycopg.NotSupportedError): + pgconn.pipeline_sync() + with pytest.raises(psycopg.NotSupportedError): + pgconn.send_flush_request() + + +@pytest.mark.libpq(">= 14") +def test_work_in_progress(pgconn): + assert not pgconn.nonblocking + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"]) + with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"): + pgconn.exit_pipeline_mode() + + +@pytest.mark.libpq(">= 14") +def test_multi_pipelines(pgconn): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"], param_types=[25]) + pgconn.pipeline_sync() + pgconn.send_query_params(b"select $1", [b"2"], param_types=[25]) + pgconn.pipeline_sync() + + # result from first query + result1 = pgconn.get_result() + assert result1 is not None + assert result1.status == pq.ExecStatus.TUPLES_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # first sync result + sync_result = pgconn.get_result() + assert sync_result is not None + assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC + + # result from second query + result2 = pgconn.get_result() + assert result2 is not None + assert result2.status == pq.ExecStatus.TUPLES_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # second sync result + sync_result = pgconn.get_result() + assert sync_result is not None + assert sync_result.status == pq.ExecStatus.PIPELINE_SYNC + + # pipeline still ON + assert pgconn.pipeline_status == pq.PipelineStatus.ON + + pgconn.exit_pipeline_mode() + + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + + assert result1.get_value(0, 0) == b"1" + assert result2.get_value(0, 0) == b"2" + + +@pytest.mark.libpq(">= 14") +def test_flush_request(pgconn): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"select $1", [b"1"], param_types=[25]) + pgconn.send_flush_request() + r = pgconn.get_result() + assert r.status == pq.ExecStatus.TUPLES_OK + assert r.get_value(0, 0) == b"1" + pgconn.exit_pipeline_mode() + + +@pytest.fixture +def table(pgconn): + tablename = "pipeline" + pgconn.exec_(f"create table {tablename} (s text)".encode("ascii")) + yield tablename + pgconn.exec_(f"drop table if exists {tablename}".encode("ascii")) + + +@pytest.mark.libpq(">= 14") +def test_pipeline_abort(pgconn, table): + assert pgconn.pipeline_status == pq.PipelineStatus.OFF + pgconn.enter_pipeline_mode() + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"1"]) + pgconn.send_query_params(b"select no_such_function($1)", [b"1"]) + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"2"]) + pgconn.pipeline_sync() + pgconn.send_query_params(b"insert into pipeline values ($1)", [b"3"]) + pgconn.pipeline_sync() + + # result from first INSERT + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.COMMAND_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # error result from second query (SELECT) + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.FATAL_ERROR + + # NULL signals end of result + assert pgconn.get_result() is None + + # pipeline should be aborted, due to previous error + assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED + + # result from second INSERT, aborted due to previous error + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_ABORTED + + # NULL signals end of result + assert pgconn.get_result() is None + + # pipeline is still aborted + assert pgconn.pipeline_status == pq.PipelineStatus.ABORTED + + # sync result + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_SYNC + + # aborted flag is clear, pipeline is on again + assert pgconn.pipeline_status == pq.PipelineStatus.ON + + # result from the third INSERT + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.COMMAND_OK + + # NULL signals end of result + assert pgconn.get_result() is None + + # second sync result + r = pgconn.get_result() + assert r is not None + assert r.status == pq.ExecStatus.PIPELINE_SYNC + + # NULL signals end of result + assert pgconn.get_result() is None + + pgconn.exit_pipeline_mode() diff --git a/tests/pq/test_pq.py b/tests/pq/test_pq.py new file mode 100644 index 0000000..076c3b6 --- /dev/null +++ b/tests/pq/test_pq.py @@ -0,0 +1,57 @@ +import os + +import pytest + +import psycopg +from psycopg import pq + +from ..utils import check_libpq_version + + +def test_version(): + rv = pq.version() + assert rv > 90500 + assert rv < 200000 # you are good for a while + + +def test_build_version(): + assert pq.__build_version__ and pq.__build_version__ >= 70400 + + +@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_BUILD')") +def test_want_built_version(): + want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_BUILD"] + got = pq.__build_version__ + assert not check_libpq_version(got, want) + + +@pytest.mark.skipif("not os.environ.get('PSYCOPG_TEST_WANT_LIBPQ_IMPORT')") +def test_want_import_version(): + want = os.environ["PSYCOPG_TEST_WANT_LIBPQ_IMPORT"] + got = pq.version() + assert not check_libpq_version(got, want) + + +# Note: These tests are here because test_pipeline.py tests are all skipped +# when pipeline mode is not supported. + + +@pytest.mark.libpq(">= 14") +def test_pipeline_supported(conn): + assert psycopg.Pipeline.is_supported() + assert psycopg.AsyncPipeline.is_supported() + + with conn.pipeline(): + pass + + +@pytest.mark.libpq("< 14") +def test_pipeline_not_supported(conn): + assert not psycopg.Pipeline.is_supported() + assert not psycopg.AsyncPipeline.is_supported() + + with pytest.raises(psycopg.NotSupportedError) as exc: + with conn.pipeline(): + pass + + assert "too old" in str(exc.value) diff --git a/tests/scripts/bench-411.py b/tests/scripts/bench-411.py new file mode 100644 index 0000000..82ea451 --- /dev/null +++ b/tests/scripts/bench-411.py @@ -0,0 +1,300 @@ +import os +import sys +import time +import random +import asyncio +import logging +from enum import Enum +from typing import Any, Dict, List, Generator +from argparse import ArgumentParser, Namespace +from contextlib import contextmanager + +logger = logging.getLogger() +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", +) + + +class Driver(str, Enum): + psycopg2 = "psycopg2" + psycopg = "psycopg" + psycopg_async = "psycopg_async" + asyncpg = "asyncpg" + + +ids: List[int] = [] +data: List[Dict[str, Any]] = [] + + +def main() -> None: + + args = parse_cmdline() + + ids[:] = range(args.ntests) + data[:] = [ + dict( + id=i, + name="c%d" % i, + description="c%d" % i, + q=i * 10, + p=i * 20, + x=i * 30, + y=i * 40, + ) + for i in ids + ] + + # Must be done just on end + drop_at_the_end = args.drop + args.drop = False + + for i, name in enumerate(args.drivers): + if i == len(args.drivers) - 1: + args.drop = drop_at_the_end + + if name == Driver.psycopg2: + import psycopg2 # type: ignore + + run_psycopg2(psycopg2, args) + + elif name == Driver.psycopg: + import psycopg + + run_psycopg(psycopg, args) + + elif name == Driver.psycopg_async: + import psycopg + + if sys.platform == "win32": + if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy() + ) + + asyncio.run(run_psycopg_async(psycopg, args)) + + elif name == Driver.asyncpg: + import asyncpg # type: ignore + + asyncio.run(run_asyncpg(asyncpg, args)) + + else: + raise AssertionError(f"unknown driver: {name!r}") + + # Must be done just on start + args.create = False + + +table = """ +CREATE TABLE customer ( + id SERIAL NOT NULL, + name VARCHAR(255), + description VARCHAR(255), + q INTEGER, + p INTEGER, + x INTEGER, + y INTEGER, + z INTEGER, + PRIMARY KEY (id) +) +""" +drop = "DROP TABLE IF EXISTS customer" + +insert = """ +INSERT INTO customer (id, name, description, q, p, x, y) VALUES +(%(id)s, %(name)s, %(description)s, %(q)s, %(p)s, %(x)s, %(y)s) +""" + +select = """ +SELECT customer.id, customer.name, customer.description, customer.q, + customer.p, customer.x, customer.y, customer.z +FROM customer +WHERE customer.id = %(id)s +""" + + +@contextmanager +def time_log(message: str) -> Generator[None, None, None]: + start = time.monotonic() + yield + end = time.monotonic() + logger.info(f"Run {message} in {end-start} s") + + +def run_psycopg2(psycopg2: Any, args: Namespace) -> None: + logger.info("Running psycopg2") + + if args.create: + logger.info(f"inserting {args.ntests} test records") + with psycopg2.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + cursor.execute(table) + cursor.executemany(insert, data) + conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + with psycopg2.connect(args.dsn) as conn: + with time_log("psycopg2"): + for id_ in to_query: + with conn.cursor() as cursor: + cursor.execute(select, {"id": id_}) + cursor.fetchall() + # conn.rollback() + + if args.drop: + logger.info("dropping test records") + with psycopg2.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + conn.commit() + + +def run_psycopg(psycopg: Any, args: Namespace) -> None: + logger.info("Running psycopg sync") + + if args.create: + logger.info(f"inserting {args.ntests} test records") + with psycopg.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + cursor.execute(table) + cursor.executemany(insert, data) + conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + with psycopg.connect(args.dsn) as conn: + with time_log("psycopg"): + for id_ in to_query: + with conn.cursor() as cursor: + cursor.execute(select, {"id": id_}) + cursor.fetchall() + # conn.rollback() + + if args.drop: + logger.info("dropping test records") + with psycopg.connect(args.dsn) as conn: + with conn.cursor() as cursor: + cursor.execute(drop) + conn.commit() + + +async def run_psycopg_async(psycopg: Any, args: Namespace) -> None: + logger.info("Running psycopg async") + + conn: Any + + if args.create: + logger.info(f"inserting {args.ntests} test records") + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + async with conn.cursor() as cursor: + await cursor.execute(drop) + await cursor.execute(table) + await cursor.executemany(insert, data) + await conn.commit() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + with time_log("psycopg_async"): + for id_ in to_query: + cursor = await conn.execute(select, {"id": id_}) + await cursor.fetchall() + await cursor.close() + # await conn.rollback() + + if args.drop: + logger.info("dropping test records") + async with await psycopg.AsyncConnection.connect(args.dsn) as conn: + async with conn.cursor() as cursor: + await cursor.execute(drop) + await conn.commit() + + +async def run_asyncpg(asyncpg: Any, args: Namespace) -> None: + logger.info("Running asyncpg") + + places = dict(id="$1", name="$2", description="$3", q="$4", p="$5", x="$6", y="$7") + a_insert = insert % places + a_select = select % {"id": "$1"} + + conn: Any + + if args.create: + logger.info(f"inserting {args.ntests} test records") + conn = await asyncpg.connect(args.dsn) + async with conn.transaction(): + await conn.execute(drop) + await conn.execute(table) + await conn.executemany(a_insert, [tuple(d.values()) for d in data]) + await conn.close() + + logger.info(f"running {args.ntests} queries") + to_query = random.choices(ids, k=args.ntests) + conn = await asyncpg.connect(args.dsn) + with time_log("asyncpg"): + for id_ in to_query: + tr = conn.transaction() + await tr.start() + await conn.fetch(a_select, id_) + # await tr.rollback() + await conn.close() + + if args.drop: + logger.info("dropping test records") + conn = await asyncpg.connect(args.dsn) + async with conn.transaction(): + await conn.execute(drop) + await conn.close() + + +def parse_cmdline() -> Namespace: + parser = ArgumentParser(description=__doc__) + parser.add_argument( + "drivers", + nargs="+", + metavar="DRIVER", + type=Driver, + help=f"the drivers to test [choices: {', '.join(d.value for d in Driver)}]", + ) + + parser.add_argument( + "--ntests", + type=int, + default=10_000, + help="number of tests to perform [default: %(default)s]", + ) + + parser.add_argument( + "--dsn", + default=os.environ.get("PSYCOPG_TEST_DSN", ""), + help="database connection string" + " [default: %(default)r (from PSYCOPG_TEST_DSN env var)]", + ) + + parser.add_argument( + "--no-create", + dest="create", + action="store_false", + default="True", + help="skip data creation before tests (it must exist already)", + ) + + parser.add_argument( + "--no-drop", + dest="drop", + action="store_false", + default="True", + help="skip data drop after tests", + ) + + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/dectest.py b/tests/scripts/dectest.py new file mode 100644 index 0000000..a49f116 --- /dev/null +++ b/tests/scripts/dectest.py @@ -0,0 +1,51 @@ +""" +A quick and rough performance comparison of text vs. binary Decimal adaptation +""" +from random import randrange +from decimal import Decimal +import psycopg +from psycopg import sql + +ncols = 10 +nrows = 500000 +format = psycopg.pq.Format.BINARY +test = "copy" + + +def main() -> None: + cnn = psycopg.connect() + + cnn.execute( + sql.SQL("create table testdec ({})").format( + sql.SQL(", ").join( + [ + sql.SQL("{} numeric(10,2)").format(sql.Identifier(f"t{i}")) + for i in range(ncols) + ] + ) + ) + ) + cur = cnn.cursor() + + if test == "copy": + with cur.copy(f"copy testdec from stdin (format {format.name})") as copy: + for j in range(nrows): + copy.write_row( + [Decimal(randrange(10000000000)) / 100 for i in range(ncols)] + ) + + elif test == "insert": + ph = ["%t", "%b"][format] + cur.executemany( + "insert into testdec values (%s)" % ", ".join([ph] * ncols), + ( + [Decimal(randrange(10000000000)) / 100 for i in range(ncols)] + for j in range(nrows) + ), + ) + else: + raise Exception(f"bad test: {test}") + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py new file mode 100644 index 0000000..ec95229 --- /dev/null +++ b/tests/scripts/pipeline-demo.py @@ -0,0 +1,340 @@ +"""Pipeline mode demo + +This reproduces libpq_pipeline::pipelined_insert PostgreSQL test at +src/test/modules/libpq_pipeline/libpq_pipeline.c::test_pipelined_insert(). + +We do not fetch results explicitly (using cursor.fetch*()), this is +handled by execute() calls when pgconn socket is read-ready, which +happens when the output buffer is full. +""" +import argparse +import asyncio +import logging +from contextlib import contextmanager +from functools import partial +from typing import Any, Iterator, Optional, Sequence, Tuple + +from psycopg import AsyncConnection, Connection +from psycopg import pq, waiting +from psycopg import errors as e +from psycopg.abc import PipelineCommand +from psycopg.generators import pipeline_communicate +from psycopg.pq import Format, DiagnosticField +from psycopg._compat import Deque + +psycopg_logger = logging.getLogger("psycopg") +pipeline_logger = logging.getLogger("pipeline") +args: argparse.Namespace + + +class LoggingPGconn: + """Wrapper for PGconn that logs fetched results.""" + + def __init__(self, pgconn: pq.abc.PGconn, logger: logging.Logger): + self._pgconn = pgconn + self._logger = logger + + def log_notice(result: pq.abc.PGresult) -> None: + def get_field(field: DiagnosticField) -> Optional[str]: + value = result.error_field(field) + return value.decode("utf-8", "replace") if value else None + + logger.info( + "notice %s %s", + get_field(DiagnosticField.SEVERITY), + get_field(DiagnosticField.MESSAGE_PRIMARY), + ) + + pgconn.notice_handler = log_notice + + if args.trace: + self._trace_file = open(args.trace, "w") + pgconn.trace(self._trace_file.fileno()) + + def __del__(self) -> None: + if hasattr(self, "_trace_file"): + self._pgconn.untrace() + self._trace_file.close() + + def __getattr__(self, name: str) -> Any: + return getattr(self._pgconn, name) + + def send_query(self, command: bytes) -> None: + self._logger.warning("PQsendQuery broken in libpq 14.5") + self._pgconn.send_query(command) + self._logger.info("sent %s", command.decode()) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_params( + command, param_values, param_types, param_formats, result_format + ) + self._logger.info("sent %s", command.decode()) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_prepared( + name, param_values, param_formats, result_format + ) + self._logger.info("sent prepared '%s' with %s", name.decode(), param_values) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + self._pgconn.send_prepare(name, command, param_types) + self._logger.info("prepare %s as '%s'", command.decode(), name.decode()) + + def get_result(self) -> Optional[pq.abc.PGresult]: + r = self._pgconn.get_result() + if r is not None: + self._logger.info("got %s result", pq.ExecStatus(r.status).name) + return r + + +@contextmanager +def prepare_pipeline_demo_pq( + pgconn: LoggingPGconn, rows_to_send: int, logger: logging.Logger +) -> Iterator[Tuple[Deque[PipelineCommand], Deque[str]]]: + """Set up pipeline demo with initial queries and yield commands and + results queue for pipeline_communicate(). + """ + logger.debug("enter pipeline") + pgconn.enter_pipeline_mode() + + setup_queries = [ + ("begin", "BEGIN TRANSACTION"), + ("drop table", "DROP TABLE IF EXISTS pq_pipeline_demo"), + ( + "create table", + ( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ), + ), + ( + "prepare", + ("INSERT INTO pq_pipeline_demo(itemno, int8filler)" " VALUES ($1, $2)"), + ), + ] + + commands = Deque[PipelineCommand]() + results_queue = Deque[str]() + + for qname, query in setup_queries: + if qname == "prepare": + pgconn.send_prepare(qname.encode(), query.encode()) + else: + pgconn.send_query_params(query.encode(), None) + results_queue.append(qname) + + committed = False + synced = False + + while True: + if rows_to_send: + params = [f"{rows_to_send}".encode(), f"{1 << 62}".encode()] + commands.append(partial(pgconn.send_query_prepared, b"prepare", params)) + results_queue.append(f"row {rows_to_send}") + rows_to_send -= 1 + + elif not committed: + committed = True + commands.append(partial(pgconn.send_query_params, b"COMMIT", None)) + results_queue.append("commit") + + elif not synced: + + def sync() -> None: + pgconn.pipeline_sync() + logger.info("pipeline sync sent") + + synced = True + commands.append(sync) + results_queue.append("sync") + + else: + break + + try: + yield commands, results_queue + finally: + logger.debug("exit pipeline") + pgconn.exit_pipeline_mode() + + +def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn(Connection.connect().pgconn, logger) + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = waiting.wait( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn((await AsyncConnection.connect()).pgconn, logger) + + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = await waiting.wait_async( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +def pipeline_demo(rows_to_send: int, many: bool, logger: logging.Logger) -> None: + """Pipeline demo using sync API.""" + conn = Connection.connect() + conn.autocommit = True + conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment] + with conn.pipeline(): + with conn.transaction(): + conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + conn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = conn.cursor() + cur.executemany(query, list(params)) + else: + for p in params: + conn.execute(query, p) + + +async def pipeline_demo_async( + rows_to_send: int, many: bool, logger: logging.Logger +) -> None: + """Pipeline demo using async API.""" + aconn = await AsyncConnection.connect() + await aconn.set_autocommit(True) + aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment] + async with aconn.pipeline(): + async with aconn.transaction(): + await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + await aconn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = aconn.cursor() + await cur.executemany(query, list(params)) + else: + for p in params: + await aconn.execute(query, p) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", + dest="nrows", + metavar="ROWS", + default=10_000, + type=int, + help="number of rows to insert", + ) + parser.add_argument( + "--pq", action="store_true", help="use low-level psycopg.pq API" + ) + parser.add_argument( + "--async", dest="async_", action="store_true", help="use async API" + ) + parser.add_argument( + "--many", + action="store_true", + help="use executemany() (not applicable for --pq)", + ) + parser.add_argument("--trace", help="write trace info into TRACE file") + parser.add_argument("-l", "--log", help="log file (stderr by default)") + + global args + args = parser.parse_args() + + psycopg_logger.setLevel(logging.DEBUG) + pipeline_logger.setLevel(logging.DEBUG) + if args.log: + psycopg_logger.addHandler(logging.FileHandler(args.log)) + pipeline_logger.addHandler(logging.FileHandler(args.log)) + else: + psycopg_logger.addHandler(logging.StreamHandler()) + pipeline_logger.addHandler(logging.StreamHandler()) + + if args.pq: + if args.many: + parser.error("--many cannot be used with --pq") + if args.async_: + asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger)) + else: + pipeline_demo_pq(args.nrows, pipeline_logger) + else: + if pq.__impl__ != "python": + parser.error( + "only supported for Python implementation (set PSYCOPG_IMPL=python)" + ) + if args.async_: + asyncio.run(pipeline_demo_async(args.nrows, args.many, pipeline_logger)) + else: + pipeline_demo(args.nrows, args.many, pipeline_logger) + + +if __name__ == "__main__": + main() diff --git a/tests/scripts/spiketest.py b/tests/scripts/spiketest.py new file mode 100644 index 0000000..2c9cc16 --- /dev/null +++ b/tests/scripts/spiketest.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +""" +Run a connection pool spike test. + +The test is inspired to the `spike analysis`__ illustrated by HikariCP + +.. __: https://github.com/brettwooldridge/HikariCP/blob/dev/documents/ + Welcome-To-The-Jungle.md + +""" +# mypy: allow-untyped-defs +# mypy: allow-untyped-calls + +import time +import threading + +import psycopg +import psycopg_pool +from psycopg.rows import Row + +import logging + + +def main() -> None: + opt = parse_cmdline() + if opt.loglevel: + loglevel = getattr(logging, opt.loglevel.upper()) + logging.basicConfig( + level=loglevel, format="%(asctime)s %(levelname)s %(message)s" + ) + + logging.getLogger("psycopg2.pool").setLevel(loglevel) + + with psycopg_pool.ConnectionPool( + opt.dsn, + min_size=opt.min_size, + max_size=opt.max_size, + connection_class=DelayedConnection, + kwargs={"conn_delay": 0.150}, + ) as pool: + pool.wait() + measurer = Measurer(pool) + + # Create and start all the thread: they will get stuck on the event + ev = threading.Event() + threads = [ + threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True) + for i in range(opt.num_clients) + ] + for t in threads: + t.start() + time.sleep(0.2) + + # Release the threads! + measurer.start(0.00025) + t0 = time.time() + ev.set() + + # Wait for the threads to finish + for t in threads: + t.join() + t1 = time.time() + measurer.stop() + + print(f"time: {(t1 - t0) * 1000} msec") + print("active,idle,total,waiting") + recs = [ + f'{m["pool_size"] - m["pool_available"]}' + f',{m["pool_available"]}' + f',{m["pool_size"]}' + f',{m["requests_waiting"]}' + for m in measurer.measures + ] + print("\n".join(recs)) + + +def worker(p, t, ev): + ev.wait() + with p.connection(): + time.sleep(t) + + +class Measurer: + def __init__(self, pool): + self.pool = pool + self.worker = None + self.stopped = False + self.measures = [] + + def start(self, interval): + self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True) + self.worker.start() + + def stop(self): + self.stopped = True + if self.worker: + self.worker.join() + self.worker = None + + def _run(self, interval): + while not self.stopped: + self.measures.append(self.pool.get_stats()) + time.sleep(interval) + + +class DelayedConnection(psycopg.Connection[Row]): + """A connection adding a delay to the connection time.""" + + @classmethod + def connect(cls, conninfo, conn_delay=0, **kwargs): + t0 = time.time() + conn = super().connect(conninfo, **kwargs) + t1 = time.time() + wait = max(0.0, conn_delay - (t1 - t0)) + if wait: + time.sleep(wait) + return conn + + +def parse_cmdline(): + from argparse import ArgumentParser + + parser = ArgumentParser(description=__doc__) + parser.add_argument("--dsn", default="", help="connection string to the database") + parser.add_argument( + "--min_size", + default=5, + type=int, + help="minimum number of connections in the pool", + ) + parser.add_argument( + "--max_size", + default=20, + type=int, + help="maximum number of connections in the pool", + ) + parser.add_argument( + "--num-clients", + default=50, + type=int, + help="number of threads making a request", + ) + parser.add_argument( + "--loglevel", + default=None, + choices=("DEBUG", "INFO", "WARNING", "ERROR"), + help="level to log at [default: no log]", + ) + + opt = parser.parse_args() + + return opt + + +if __name__ == "__main__": + main() diff --git a/tests/test_adapt.py b/tests/test_adapt.py new file mode 100644 index 0000000..2190a84 --- /dev/null +++ b/tests/test_adapt.py @@ -0,0 +1,530 @@ +import datetime as dt +from types import ModuleType +from typing import Any, List + +import pytest + +import psycopg +from psycopg import pq, sql, postgres +from psycopg import errors as e +from psycopg.adapt import Transformer, PyFormat, Dumper, Loader +from psycopg._cmodule import _psycopg +from psycopg.postgres import types as builtins, TEXT_OID +from psycopg.types.array import ListDumper, ListBinaryDumper + + +@pytest.mark.parametrize( + "data, format, result, type", + [ + (1, PyFormat.TEXT, b"1", "int2"), + ("hello", PyFormat.TEXT, b"hello", "text"), + ("hello", PyFormat.BINARY, b"hello", "text"), + ], +) +def test_dump(data, format, result, type): + t = Transformer() + dumper = t.get_dumper(data, format) + assert dumper.dump(data) == result + if type == "text" and format != PyFormat.BINARY: + assert dumper.oid == 0 + else: + assert dumper.oid == builtins[type].oid + + +@pytest.mark.parametrize( + "data, result", + [ + (1, b"1"), + ("hello", b"'hello'"), + ("he'llo", b"'he''llo'"), + (True, b"true"), + (None, b"NULL"), + ], +) +def test_quote(data, result): + t = Transformer() + dumper = t.get_dumper(data, PyFormat.TEXT) + assert dumper.quote(data) == result + + +def test_register_dumper_by_class(conn): + dumper = make_dumper("x") + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper + conn.adapters.register_dumper(MyStr, dumper) + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper + + +def test_register_dumper_by_class_name(conn): + dumper = make_dumper("x") + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper + conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper) + assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper + + +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") +def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn): + psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) + psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) + with conn_cls.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_dump_connection_ctx(conn): + conn.adapters.register_dumper(MyStr, make_bin_dumper("b")) + conn.adapters.register_dumper(MyStr, make_dumper("t")) + + cur = conn.cursor() + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellob",) + + +def test_dump_cursor_ctx(conn): + conn.adapters.register_dumper(str, make_bin_dumper("b")) + conn.adapters.register_dumper(str, make_dumper("t")) + + cur = conn.cursor() + cur.adapters.register_dumper(str, make_bin_dumper("bc")) + cur.adapters.register_dumper(str, make_dumper("tc")) + + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellotc",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellotc",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellobc",) + + cur = conn.cursor() + cur.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellot",) + cur.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellob",) + + +def test_dump_subclass(conn): + class MyString(str): + pass + + cur = conn.cursor() + cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")]) + assert cur.fetchone() == ("hello", "world") + + +def test_subclass_dumper(conn): + # This might be a C fast object: make sure that the Python code is called + from psycopg.types.string import StrDumper + + class MyStrDumper(StrDumper): + def dump(self, obj): + return (obj * 2).encode() + + conn.adapters.register_dumper(str, MyStrDumper) + assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello" + + +def test_dumper_protocol(conn): + + # This class doesn't inherit from adapt.Dumper but passes a mypy check + from .adapters_example import MyStrDumper + + conn.adapters.register_dumper(str, MyStrDumper) + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellohello" + cur = conn.execute("select %s", [["hi", "ha"]]) + assert cur.fetchone()[0] == ["hihi", "haha"] + assert sql.Literal("hello").as_string(conn) == "'qelloqello'" + + +def test_loader_protocol(conn): + + # This class doesn't inherit from adapt.Loader but passes a mypy check + from .adapters_example import MyTextLoader + + conn.adapters.register_loader("text", MyTextLoader) + cur = conn.execute("select 'hello'::text") + assert cur.fetchone()[0] == "hellohello" + cur = conn.execute("select '{hi,ha}'::text[]") + assert cur.fetchone()[0] == ["hihi", "haha"] + + +def test_subclass_loader(conn): + # This might be a C fast object: make sure that the Python code is called + from psycopg.types.string import TextLoader + + class MyTextLoader(TextLoader): + def load(self, data): + return (bytes(data) * 2).decode() + + conn.adapters.register_loader("text", MyTextLoader) + assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello" + + +@pytest.mark.parametrize( + "data, format, type, result", + [ + (b"1", pq.Format.TEXT, "int4", 1), + (b"hello", pq.Format.TEXT, "text", "hello"), + (b"hello", pq.Format.BINARY, "text", "hello"), + ], +) +def test_cast(data, format, type, result): + t = Transformer() + rv = t.get_loader(builtins[type].oid, format).load(data) + assert rv == result + + +def test_register_loader_by_oid(conn): + assert TEXT_OID == 25 + loader = make_loader("x") + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader + conn.adapters.register_loader(TEXT_OID, loader) + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader + + +def test_register_loader_by_type_name(conn): + loader = make_loader("x") + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader + conn.adapters.register_loader("text", loader) + assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader + + +@pytest.mark.crdb("skip", reason="global adapters don't affect crdb") +def test_load_global_ctx(conn_cls, dsn, global_adapters): + psycopg.adapters.register_loader("text", make_loader("gt")) + psycopg.adapters.register_loader("text", make_bin_loader("gb")) + with conn_cls.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) + + +def test_load_connection_ctx(conn): + conn.adapters.register_loader("text", make_loader("t")) + conn.adapters.register_loader("text", make_bin_loader("b")) + + r = conn.cursor(binary=False).execute("select 'hello'::text").fetchone() + assert r == ("hellot",) + r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone() + assert r == ("hellob",) + + +def test_load_cursor_ctx(conn): + conn.adapters.register_loader("text", make_loader("t")) + conn.adapters.register_loader("text", make_bin_loader("b")) + + cur = conn.cursor() + cur.adapters.register_loader("text", make_loader("tc")) + cur.adapters.register_loader("text", make_bin_loader("bc")) + + assert cur.execute("select 'hello'::text").fetchone() == ("hellotc",) + cur.format = pq.Format.BINARY + assert cur.execute("select 'hello'::text").fetchone() == ("hellobc",) + + cur = conn.cursor() + assert cur.execute("select 'hello'::text").fetchone() == ("hellot",) + cur.format = pq.Format.BINARY + assert cur.execute("select 'hello'::text").fetchone() == ("hellob",) + + +def test_cow_dumpers(conn): + conn.adapters.register_dumper(str, make_dumper("t")) + + cur1 = conn.cursor() + cur2 = conn.cursor() + cur2.adapters.register_dumper(str, make_dumper("c2")) + + r = cur1.execute("select %s::text -- 1", ["hello"]).fetchone() + assert r == ("hellot",) + r = cur2.execute("select %s::text -- 1", ["hello"]).fetchone() + assert r == ("helloc2",) + + conn.adapters.register_dumper(str, make_dumper("t1")) + r = cur1.execute("select %s::text -- 2", ["hello"]).fetchone() + assert r == ("hellot",) + r = cur2.execute("select %s::text -- 2", ["hello"]).fetchone() + assert r == ("helloc2",) + + +def test_cow_loaders(conn): + conn.adapters.register_loader("text", make_loader("t")) + + cur1 = conn.cursor() + cur2 = conn.cursor() + cur2.adapters.register_loader("text", make_loader("c2")) + + assert cur1.execute("select 'hello'::text").fetchone() == ("hellot",) + assert cur2.execute("select 'hello'::text").fetchone() == ("helloc2",) + + conn.adapters.register_loader("text", make_loader("t1")) + assert cur1.execute("select 'hello2'::text").fetchone() == ("hello2t",) + assert cur2.execute("select 'hello2'::text").fetchone() == ("hello2c2",) + + +@pytest.mark.parametrize( + "sql, obj", + [("'{hello}'::text[]", ["helloc"]), ("row('hello'::text)", ("helloc",))], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out): + cur = conn.cursor(binary=fmt_out == pq.Format.BINARY) + if fmt_out == pq.Format.TEXT: + cur.adapters.register_loader("text", make_loader("c")) + else: + cur.adapters.register_loader("text", make_bin_loader("c")) + + cur.execute(f"select {sql}") + res = cur.fetchone()[0] + assert res == obj + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_list_dumper(conn, fmt_out): + t = Transformer(conn) + fmt_in = PyFormat.from_pq(fmt_out) + dint = t.get_dumper([0], fmt_in) + assert isinstance(dint, (ListDumper, ListBinaryDumper)) + assert dint.oid == builtins["int2"].array_oid + assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid + + dstr = t.get_dumper([""], fmt_in) + assert dstr is not dint + + assert t.get_dumper([1], fmt_in) is dint + assert t.get_dumper([None, [1]], fmt_in) is dint + + dempty = t.get_dumper([], fmt_in) + assert t.get_dumper([None, [None]], fmt_in) is dempty + assert dempty.oid == 0 + assert dempty.dump([]) == b"{}" + + L: List[List[Any]] = [] + L.append(L) + with pytest.raises(psycopg.DataError): + assert t.get_dumper(L, fmt_in) + + +@pytest.mark.crdb("skip", reason="test in crdb test suite") +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == 0 + assert dstr.sub_dumper and dstr.sub_dumper.oid == 0 + + +def test_str_list_dumper_binary(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.BINARY) + assert isinstance(dstr, ListBinaryDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + +def test_last_dumper_registered_ctx(conn): + cur = conn.cursor() + + bd = make_bin_dumper("b") + cur.adapters.register_dumper(str, bd) + td = make_dumper("t") + cur.adapters.register_dumper(str, td) + + assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellot" + assert cur.execute("select %t", ["hello"]).fetchone()[0] == "hellot" + assert cur.execute("select %b", ["hello"]).fetchone()[0] == "hellob" + + cur.adapters.register_dumper(str, bd) + assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellob" + + +@pytest.mark.parametrize("fmt_in", [PyFormat.TEXT, PyFormat.BINARY]) +def test_none_type_argument(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table none_args (id serial primary key, num integer)") + cur.execute("insert into none_args (num) values (%s) returning id", (None,)) + assert cur.fetchone()[0] + + +@pytest.mark.crdb("skip", reason="test in crdb test suite") +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_return_untyped(conn, fmt_in): + # Analyze and check for changes using strings in untyped/typed contexts + cur = conn.cursor() + # Currently string are passed as unknown oid to libpq. This is because + # unknown is more easily cast by postgres to different types (see jsonb + # later). + cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10]) + assert cur.fetchone() == ("hello", 10) + + cur.execute("create table testjson(data jsonb)") + if fmt_in != PyFormat.BINARY: + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + assert cur.execute("select data from testjson").fetchone() == ({},) + else: + # Binary types cannot be passed as unknown oids. + with pytest.raises(e.DatatypeMismatch): + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_no_cast_needed(conn, fmt_in): + # Verify that there is no need of cast in certain common scenario + cur = conn.execute(f"select '2021-01-01'::date + %{fmt_in.value}", [3]) + assert cur.fetchone()[0] == dt.date(2021, 1, 4) + + cur = conn.execute(f"select '[10, 20, 30]'::jsonb -> %{fmt_in.value}", [1]) + assert cur.fetchone()[0] == 20 + + +@pytest.mark.slow +@pytest.mark.skipif(_psycopg is None, reason="C module test") +def test_optimised_adapters(): + + # All the optimised adapters available + c_adapters = {} + for n in dir(_psycopg): + if n.startswith("_") or n in ("CDumper", "CLoader"): + continue + obj = getattr(_psycopg, n) + if not isinstance(obj, type): + continue + if not issubclass( + obj, + (_psycopg.CDumper, _psycopg.CLoader), # type: ignore[attr-defined] + ): + continue + c_adapters[n] = obj + + # All the registered adapters + reg_adapters = set() + adapters = list(postgres.adapters._dumpers.values()) + postgres.adapters._loaders + assert len(adapters) == 5 + for m in adapters: + reg_adapters |= set(m.values()) + + # Check that the registered adapters are the optimised one + i = 0 + for cls in reg_adapters: + if cls.__name__ in c_adapters: + assert cls is c_adapters[cls.__name__] + i += 1 + + assert i >= 10 + + # Check that every optimised adapter is the optimised version of a Py one + for n in dir(psycopg.types): + mod = getattr(psycopg.types, n) + if not isinstance(mod, ModuleType): + continue + for n1 in dir(mod): + obj = getattr(mod, n1) + if not isinstance(obj, type): + continue + if not issubclass(obj, (Dumper, Loader)): + continue + c_adapters.pop(obj.__name__, None) + + assert not c_adapters + + +def test_dumper_init_error(conn): + class BadDumper(Dumper): + def __init__(self, cls, context): + super().__init__(cls, context) + 1 / 0 + + def dump(self, obj): + return obj.encode() + + cur = conn.cursor() + cur.adapters.register_dumper(str, BadDumper) + with pytest.raises(ZeroDivisionError): + cur.execute("select %s::text", ["hi"]) + + +def test_loader_init_error(conn): + class BadLoader(Loader): + def __init__(self, oid, context): + super().__init__(oid, context) + 1 / 0 + + def load(self, data): + return data.decode() + + cur = conn.cursor() + cur.adapters.register_loader("text", BadLoader) + with pytest.raises(ZeroDivisionError): + cur.execute("select 'hi'::text") + assert cur.fetchone() == ("hi",) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_random(conn, faker, fmt, fmt_out): + faker.format = fmt + faker.choose_schema(ncols=20) + faker.make_records(50) + + with conn.cursor(binary=fmt_out) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +class MyStr(str): + pass + + +def make_dumper(suffix): + """Create a test dumper appending a suffix to the bytes representation.""" + + class TestDumper(Dumper): + oid = TEXT_OID + format = pq.Format.TEXT + + def dump(self, s): + return (s + suffix).encode("ascii") + + return TestDumper + + +def make_bin_dumper(suffix): + cls = make_dumper(suffix) + cls.format = pq.Format.BINARY + return cls + + +def make_loader(suffix): + """Create a test loader appending a suffix to the data returned.""" + + class TestLoader(Loader): + format = pq.Format.TEXT + + def load(self, b): + return bytes(b).decode("ascii") + suffix + + return TestLoader + + +def make_bin_loader(suffix): + cls = make_loader(suffix) + cls.format = pq.Format.BINARY + return cls diff --git a/tests/test_client_cursor.py b/tests/test_client_cursor.py new file mode 100644 index 0000000..b355604 --- /dev/null +++ b/tests/test_client_cursor.py @@ -0,0 +1,855 @@ +import pickle +import weakref +import datetime as dt +from typing import List + +import pytest + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins + +from .utils import gc_collect, gc_count +from .test_cursor import my_row_factory +from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision + + +@pytest.fixture +def conn(conn): + conn.cursor_factory = psycopg.ClientCursor + return conn + + +def test_init(conn): + cur = psycopg.ClientCursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.ClientCursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.ClientCursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_from_cursor_factory(conn_cls, dsn): + with conn_cls.connect(dsn, cursor_factory=psycopg.ClientCursor) as conn: + cur = conn.cursor() + assert type(cur) is psycopg.ClientCursor + + cur.execute("select %s", (1,)) + assert cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +def test_close(conn): + cur = conn.cursor() + assert not cur.closed + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.execute("select 'foo'") + + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchall() + + +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +def test_weakref(conn): + cur = conn.cursor() + w = weakref.ref(cur) + cur.close() + del cur + gc_collect() + assert w() is None + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_statusmessage(conn): + cur = conn.cursor() + assert cur.statusmessage is None + + cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + cur.execute("wat") + assert cur.statusmessage is None + + +def test_execute_sql(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {value}").format(value="hello")) + assert cur.fetchone() == ("hello",) + + +def test_execute_many_results(conn): + cur = conn.cursor() + assert cur.nextset() is None + + rv = cur.execute("select %s; select generate_series(1,%s)", ("foo", 3)) + assert rv is cur + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.nextset() is None + + cur.close() + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +def test_execute_empty_query(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +def test_execute_type_change(conn): + # issue #112 + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.execute(sql, (1,)) + cur.execute(sql, (100_000,)) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +def test_executemany_type_change(conn): + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.executemany(sql, [(1,), (100_000,)]) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +def test_execute_copy(conn, query): + cur = conn.cursor() + cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + cur.execute(query) + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = cur.fetchone() + assert row == (1, "foo", None) + row = cur.fetchone() + assert row is None + + +def test_binary_cursor_execute(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None]) + + +def test_execute_binary(conn): + with pytest.raises(psycopg.NotSupportedError): + cur = conn.cursor() + cur.execute("select %s, %s", [1, None], binary=True) + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_query_encode(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_query_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +def test_executemany_no_data(conn, execmany): + cur = conn.cursor() + cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +def test_executemany_returning(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_returning_discard(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + assert cur.nextset() is None + + +def test_executemany_no_result(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +def test_executemany_rowcount_no_hit(conn, execmany): + cur = conn.cursor() + cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)]) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + # This fails, but only because we try to copy in pipeline mode, + # crashing the connection. Which would be even fine, but with + # the async cursor it's worse... See test_client_cursor_async.py. + # "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_executemany_null_first(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table testmany (a bigint, b bigint)") + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +def test_rowcount(conn): + cur = conn.cursor() + + cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)") + assert cur.rowcount == 42 + + +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + +def test_row_factory(conn): + cur = conn.cursor(row_factory=my_row_factory) + + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" + + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert cur.fetchall() == [["Yy", "Zz"]] + + cur.scroll(-1) + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} + + +def test_row_factory_none(conn): + cur = conn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + r = cur.execute("select 1 as a, 2 as b").fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +def test_bad_row_factory(conn): + def broken_factory(cur): + 1 / 0 + + cur = conn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = conn.cursor(row_factory=broken_maker) + cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + cur.fetchone() + + +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + +def test_query_params_execute(conn): + cur = conn.cursor() + assert cur._query is None + + cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select %(x)s", {"x": 1}, (1,)), + ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)), + ("select %(x)s, %(x)s", {"x": 1}, (1, 1)), + ], +) +def test_query_params_named(conn, query, params, want): + cur = conn.cursor() + cur.execute(query, params) + rec = cur.fetchone() + assert rec == want + + +def test_query_params_executemany(conn): + cur = conn.cursor() + + cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_stream(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +class TestColumn: + def test_description_attribs(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + assert len(curs.description) == 3 + for c in curs.description: + len(c) == 7 # DBAPI happy + for i, a in enumerate( + """ + name type_code display_size internal_size precision scale null_ok + """.split() + ): + assert c[i] == getattr(c, a) + + # Won't fill them up + assert c.null_ok is None + + c = curs.description[0] + assert c.name == "pi" + assert c.type_code == builtins["numeric"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision == 10 + assert c.scale == 2 + + c = curs.description[1] + assert c.name == "hi" + assert c.type_code == builtins["text"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision is None + assert c.scale is None + + c = curs.description[2] + assert c.name == "now" + assert c.type_code == builtins["date"].oid + assert c.display_size is None + if is_crdb(conn): + assert c.internal_size == 16 + else: + assert c.internal_size == 4 + assert c.precision is None + assert c.scale is None + + def test_description_slice(self, conn): + curs = conn.cursor() + curs.execute("select 1::int as a") + curs.description[0][0:2] == ("a", 23) + + @pytest.mark.parametrize( + "type, precision, scale, dsize, isize", + [ + ("text", None, None, None, None), + ("varchar", None, None, None, None), + ("varchar(42)", None, None, 42, None), + ("int4", None, None, None, 4), + ("numeric", None, None, None, None), + ("numeric(10)", 10, 0, None, None), + ("numeric(10, 3)", 10, 3, None, None), + ("time", None, None, None, 8), + crdb_time_precision("time(4)", 4, None, None, 8), + crdb_time_precision("time(10)", 6, None, None, 8), + ], + ) + def test_details(self, conn, type, precision, scale, dsize, isize): + cur = conn.cursor() + cur.execute(f"select null::{type}") + col = cur.description[0] + repr(col) + assert col.precision == precision + assert col.scale == scale + assert col.display_size == dsize + assert col.internal_size == isize + + def test_pickle(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + description = curs.description + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) + assert [tuple(d) for d in description] == [tuple(d) for d in unpickled] + + @pytest.mark.crdb_skip("no col query") + def test_no_col_query(self, conn): + cur = conn.execute("select") + assert cur.description == [] + assert cur.fetchall() == [()] + + def test_description_closed_connection(self, conn): + # If we have reasons to break this test we will (e.g. we really need + # the connection). In #172 it fails just by accident. + cur = conn.execute("select 1::int4 as foo") + conn.close() + assert len(cur.description) == 1 + col = cur.description[0] + assert col.name == "foo" + assert col.type_code == 23 + + def test_name_not_a_name(self, conn): + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone() + assert res == "x" + assert cur.description[0].name == "foo-bar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_name_encode(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone() + assert res == "x" + assert cur.description[0].name == "\u20ac" + + +def test_str(conn): + cur = conn.cursor() + assert "psycopg.ClientCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_leak(conn_cls, dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + def work(): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): + with psycopg.ClientCursor(conn, row_factory=row_factory) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + cur.fetchall() + elif fetch == "iter": + for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select 'hello'", (), "select 'hello'"), + ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"), + ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"), + ("select %%", (), "select %%"), + ("select %%, %s", (["a"],), "select %, 'a'"), + ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"), + ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"), + ], +) +def test_mogrify(conn, query, params, want): + cur = conn.cursor() + got = cur.mogrify(query, *params) + assert got == want + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_mogrify_encoding(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + q = conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + assert q == "select '\u20ac'" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_mogrify_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + with pytest.raises(UnicodeEncodeError): + conn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + + +@pytest.mark.pipeline +def test_message_0x33(conn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + conn.autocommit = True + with conn.pipeline(): + cur = conn.execute("select 'test'") + assert cur.fetchone() == ("test",) + + assert not notices diff --git a/tests/test_client_cursor_async.py b/tests/test_client_cursor_async.py new file mode 100644 index 0000000..0cf8ec6 --- /dev/null +++ b/tests/test_client_cursor_async.py @@ -0,0 +1,727 @@ +import pytest +import weakref +import datetime as dt +from typing import List + +import psycopg +from psycopg import sql, rows +from psycopg.adapt import PyFormat + +from .utils import alist, gc_collect, gc_count +from .test_cursor import my_row_factory +from .test_cursor import execmany, _execmany # noqa: F401 +from .fix_crdb import crdb_encoding + +execmany = execmany # avoid F811 underneath +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def aconn(aconn): + aconn.cursor_factory = psycopg.AsyncClientCursor + return aconn + + +async def test_init(aconn): + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = psycopg.AsyncClientCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncClientCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_from_cursor_factory(aconn_cls, dsn): + async with await aconn_cls.connect( + dsn, cursor_factory=psycopg.AsyncClientCursor + ) as aconn: + cur = aconn.cursor() + assert type(cur) is psycopg.AsyncClientCursor + + await cur.execute("select %s", (1,)) + assert await cur.fetchone() == (1,) + assert cur._query + assert cur._query.query == b"select 1" + + +async def test_close(aconn): + cur = aconn.cursor() + assert not cur.closed + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.execute("select 'foo'") + + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchall() + + +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +async def test_weakref(aconn): + cur = aconn.cursor() + w = weakref.ref(cur) + await cur.close() + del cur + gc_collect() + assert w() is None + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_statusmessage(aconn): + cur = aconn.cursor() + assert cur.statusmessage is None + + await cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + await cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + await cur.execute("wat") + assert cur.statusmessage is None + + +async def test_execute_sql(aconn): + cur = aconn.cursor() + await cur.execute(sql.SQL("select {value}").format(value="hello")) + assert await cur.fetchone() == ("hello",) + + +async def test_execute_many_results(aconn): + cur = aconn.cursor() + assert cur.nextset() is None + + rv = await cur.execute("select %s; select generate_series(1,%s)", ("foo", 3)) + assert rv is cur + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 + assert cur.nextset() is None + + await cur.close() + assert cur.nextset() is None + + +async def test_execute_sequence(aconn): + cur = aconn.cursor() + rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +async def test_execute_empty_query(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + + +async def test_execute_type_change(aconn): + # issue #112 + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.execute(sql, (1,)) + await cur.execute(sql, (100_000,)) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +async def test_executemany_type_change(aconn): + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.executemany(sql, [(1,), (100_000,)]) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +async def test_execute_copy(aconn, query): + cur = aconn.cursor() + await cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + await cur.execute(query) + + +async def test_fetchone(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = await cur.fetchone() + assert row == (1, "foo", None) + row = await cur.fetchone() + assert row is None + + +async def test_binary_cursor_execute(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None]) + + +async def test_execute_binary(aconn): + with pytest.raises(psycopg.NotSupportedError): + cur = aconn.cursor() + await cur.execute("select %s, %s", [1, None], binary=True) + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_query_encode(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + await cur.execute("select '\u20ac'") + (res,) = await cur.fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_query_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + await cur.execute("select '\u20ac'") + + +async def test_executemany(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] + + +async def test_executemany_name(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] + + +async def test_executemany_no_data(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +async def test_executemany_rowcount(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +async def test_executemany_returning(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_returning_discard(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + assert cur.nextset() is None + + +async def test_executemany_no_result(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +async def test_executemany_rowcount_no_hit(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + await cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + await cur.executemany( + "delete from execmany where id = %s returning num", [(-1,), (-2,)] + ) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + # This fails because we end up trying to copy in pipeline mode. + # However, sometimes (and pretty regularly if we enable pgconn.trace()) + # something goes in a loop and only terminates by OOM. Strace shows + # an allocation loop. I think it's in the libpq. + # "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +async def test_executemany_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +async def test_executemany_null_first(aconn, fmt_in): + cur = aconn.cursor() + await cur.execute("create table testmany (a bigint, b bigint)") + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + +async def test_rownumber(aconn): + cur = aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + async for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False + + +async def test_row_factory(aconn): + cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + + await cur.scroll(-1) + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} + + +async def test_row_factory_none(aconn): + cur = aconn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + await cur.execute("select 1 as a, 2 as b") + r = await cur.fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +async def test_bad_row_factory(aconn): + def broken_factory(cur): + 1 / 0 + + cur = aconn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + await cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = aconn.cursor(row_factory=broken_maker) + await cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + await cur.fetchone() + + +async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + +async def test_query_params_execute(aconn): + cur = aconn.cursor() + assert cur._query is None + + await cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select 1, NULL::text" + assert cur._query.params == (b"1", b"NULL") + + await cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + await cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select 'wat'::int" + assert cur._query.params == (b"'wat'",) + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select %(x)s", {"x": 1}, (1,)), + ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)), + ("select %(x)s, %(x)s", {"x": 1}, (1, 1)), + ], +) +async def test_query_params_named(aconn, query, params, want): + cur = aconn.cursor() + await cur.execute(query, params) + rec = await cur.fetchone() + assert rec == want + + +async def test_query_params_executemany(aconn): + cur = aconn.cursor() + + await cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select 3, 4" + assert cur._query.params == (b"3", b"4") + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_stream(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_str(aconn): + cur = aconn.cursor() + assert "psycopg.AsyncClientCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + await cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + await cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_leak(aconn_cls, dsn, faker, fetch, row_factory): + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + async def work(): + async with await aconn_cls.connect(dsn) as conn, conn.transaction( + force_rollback=True + ): + async with psycopg.AsyncClientCursor(conn, row_factory=row_factory) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + await cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = await cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = await cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + await cur.fetchall() + elif fetch == "iter": + async for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.parametrize( + "query, params, want", + [ + ("select 'hello'", (), "select 'hello'"), + ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"), + ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"), + ("select %%", (), "select %%"), + ("select %%, %s", (["a"],), "select %, 'a'"), + ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"), + ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"), + ], +) +async def test_mogrify(aconn, query, params, want): + cur = aconn.cursor() + got = cur.mogrify(query, *params) + assert got == want + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_mogrify_encoding(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + q = aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + assert q == "select '\u20ac'" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_mogrify_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + with pytest.raises(UnicodeEncodeError): + aconn.cursor().mogrify("select %(s)s", {"s": "\u20ac"}) + + +@pytest.mark.pipeline +async def test_message_0x33(aconn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + await aconn.set_autocommit(True) + async with aconn.pipeline(): + cur = await aconn.execute("select 'test'") + assert (await cur.fetchone()) == ("test",) + + assert not notices diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..eec24f1 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,327 @@ +""" +Tests dealing with concurrency issues. +""" + +import os +import sys +import time +import queue +import signal +import threading +import multiprocessing +import subprocess as sp +from typing import List + +import pytest + +import psycopg +from psycopg import errors as e + + +@pytest.mark.slow +def test_concurrent_execution(conn_cls, dsn): + def worker(): + cnn = conn_cls.connect(dsn) + cur = cnn.cursor() + cur.execute("select pg_sleep(0.5)") + cur.close() + cnn.close() + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t0 = time.time() + t1.start() + t2.start() + t1.join() + t2.join() + assert time.time() - t0 < 0.8, "something broken in concurrency" + + +@pytest.mark.slow +def test_commit_concurrency(conn): + # Check the condition reported in psycopg2#103 + # Because of bad status check, we commit even when a commit is already on + # its way. We can detect this condition by the warnings. + notices = queue.Queue() # type: ignore[var-annotated] + conn.add_notice_handler(lambda diag: notices.put(diag.message_primary)) + stop = False + + def committer(): + nonlocal stop + while not stop: + conn.commit() + + cur = conn.cursor() + t1 = threading.Thread(target=committer) + t1.start() + for i in range(1000): + cur.execute("select %s;", (i,)) + conn.commit() + + # Stop the committer thread + stop = True + t1.join() + + assert notices.empty(), "%d notices raised" % notices.qsize() + + +@pytest.mark.slow +@pytest.mark.subprocess +def test_multiprocess_close(dsn, tmpdir): + # Check the problem reported in psycopg2#829 + # Subprocess gcs the copy of the fd after fork so it closes connection. + module = f"""\ +import time +import psycopg + +def thread(): + conn = psycopg.connect({dsn!r}) + curs = conn.cursor() + for i in range(10): + curs.execute("select 1") + time.sleep(0.1) + +def process(): + time.sleep(0.2) +""" + + script = """\ +import time +import threading +import multiprocessing +import mptest + +t = threading.Thread(target=mptest.thread, name='mythread') +t.start() +time.sleep(0.2) +multiprocessing.Process(target=mptest.process, name='myprocess').start() +t.join() +""" + + with (tmpdir / "mptest.py").open("w") as f: + f.write(module) + env = dict(os.environ) + env["PYTHONPATH"] = str(tmpdir + os.pathsep + env.get("PYTHONPATH", "")) + out = sp.check_output( + [sys.executable, "-c", script], stderr=sp.STDOUT, env=env + ).decode("utf8", "replace") + assert out == "", out.strip().splitlines()[-1] + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("notify") +def test_notifies(conn_cls, conn, dsn): + nconn = conn_cls.connect(dsn, autocommit=True) + npid = nconn.pgconn.backend_pid + + def notifier(): + time.sleep(0.25) + nconn.cursor().execute("notify foo, '1'") + time.sleep(0.25) + nconn.cursor().execute("notify foo, '2'") + nconn.close() + + conn.autocommit = True + conn.cursor().execute("listen foo") + + t0 = time.time() + t = threading.Thread(target=notifier) + t.start() + + ns = [] + gen = conn.notifies() + for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + gen.close() + + assert len(ns) == 2 + + n, t1 = ns[0] + assert isinstance(n, psycopg.Notify) + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + t.join() + + +def canceller(conn, errors): + try: + time.sleep(0.5) + conn.cancel() + except Exception as exc: + errors.append(exc) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +def test_cancel(conn): + errors: List[Exception] = [] + + cur = conn.cursor() + t = threading.Thread(target=canceller, args=(conn, errors)) + t0 = time.time() + t.start() + + with pytest.raises(e.QueryCanceled): + cur.execute("select pg_sleep(2)") + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 + + t.join() + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +def test_cancel_stream(conn): + errors: List[Exception] = [] + + cur = conn.cursor() + t = threading.Thread(target=canceller, args=(conn, errors)) + t0 = time.time() + t.start() + + with pytest.raises(e.QueryCanceled): + for row in cur.stream("select pg_sleep(2)"): + pass + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 + + t.join() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +@pytest.mark.slow +def test_identify_closure(conn_cls, dsn): + def closer(): + time.sleep(0.2) + conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]) + + conn = conn_cls.connect(dsn) + conn2 = conn_cls.connect(dsn) + try: + t = threading.Thread(target=closer) + t.start() + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_sleep(1.0)") + t1 = time.time() + assert 0.2 < t1 - t0 < 0.4 + finally: + t.join() + finally: + conn.close() + conn2.close() + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + sys.platform == "win32", reason="don't know how to Ctrl-C on Windows" +) +@pytest.mark.crdb_skip("cancel") +def test_ctrl_c(dsn): + if sys.platform == "win32": + sig = int(signal.CTRL_C_EVENT) + # Or pytest will receive the Ctrl-C too + creationflags = sp.CREATE_NEW_PROCESS_GROUP + else: + sig = int(signal.SIGINT) + creationflags = 0 + + script = f"""\ +import os +import time +import psycopg +from threading import Thread + +def tired_of_life(): + time.sleep(1) + os.kill(os.getpid(), {sig!r}) + +t = Thread(target=tired_of_life, daemon=True) +t.start() + +with psycopg.connect({dsn!r}) as conn: + cur = conn.cursor() + ctrl_c = False + try: + cur.execute("select pg_sleep(2)") + except KeyboardInterrupt: + ctrl_c = True + + assert ctrl_c, "ctrl-c not received" + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + ), f"transaction status: {{conn.info.transaction_status!r}}" + + conn.rollback() + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE + ), f"transaction status: {{conn.info.transaction_status!r}}" + + cur.execute("select 1") + assert cur.fetchone() == (1,) +""" + t0 = time.time() + proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags) + proc.communicate() + t = time.time() - t0 + assert proc.returncode == 0 + assert 1 < t < 2 + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + multiprocessing.get_all_start_methods()[0] != "fork", + reason="problematic behavior only exhibited via fork", +) +def test_segfault_on_fork_close(dsn): + # https://github.com/psycopg/psycopg/issues/300 + script = f"""\ +import gc +import psycopg +from multiprocessing import Pool + +def test(arg): + conn1 = psycopg.connect({dsn!r}) + conn1.close() + conn1 = None + gc.collect() + return 1 + +if __name__ == '__main__': + conn = psycopg.connect({dsn!r}) + with Pool(2) as p: + pool_result = p.map_async(test, [1, 2]) + pool_result.wait(timeout=5) + if pool_result.ready(): + print(pool_result.get(timeout=1)) +""" + env = dict(os.environ) + env["PYTHONFAULTHANDLER"] = "1" + out = sp.check_output([sys.executable, "-s", "-c", script], env=env) + assert out.decode().rstrip() == "[1, 1]" diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py new file mode 100644 index 0000000..29b08cf --- /dev/null +++ b/tests/test_concurrency_async.py @@ -0,0 +1,242 @@ +import sys +import time +import signal +import asyncio +import subprocess as sp +from asyncio.queues import Queue +from typing import List, Tuple + +import pytest + +import psycopg +from psycopg import errors as e +from psycopg._compat import create_task + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.slow +async def test_commit_concurrency(aconn): + # Check the condition reported in psycopg2#103 + # Because of bad status check, we commit even when a commit is already on + # its way. We can detect this condition by the warnings. + notices = Queue() # type: ignore[var-annotated] + aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary)) + stop = False + + async def committer(): + nonlocal stop + while not stop: + await aconn.commit() + await asyncio.sleep(0) # Allow the other worker to work + + async def runner(): + nonlocal stop + cur = aconn.cursor() + for i in range(1000): + await cur.execute("select %s;", (i,)) + await aconn.commit() + + # Stop the committer thread + stop = True + + await asyncio.gather(committer(), runner()) + assert notices.empty(), "%d notices raised" % notices.qsize() + + +@pytest.mark.slow +async def test_concurrent_execution(aconn_cls, dsn): + async def worker(): + cnn = await aconn_cls.connect(dsn) + cur = cnn.cursor() + await cur.execute("select pg_sleep(0.5)") + await cur.close() + await cnn.close() + + workers = [worker(), worker()] + t0 = time.time() + await asyncio.gather(*workers) + assert time.time() - t0 < 0.8, "something broken in concurrency" + + +@pytest.mark.slow +@pytest.mark.timing +@pytest.mark.crdb_skip("notify") +async def test_notifies(aconn_cls, aconn, dsn): + nconn = await aconn_cls.connect(dsn, autocommit=True) + npid = nconn.pgconn.backend_pid + + async def notifier(): + cur = nconn.cursor() + await asyncio.sleep(0.25) + await cur.execute("notify foo, '1'") + await asyncio.sleep(0.25) + await cur.execute("notify foo, '2'") + await nconn.close() + + async def receiver(): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + gen = aconn.notifies() + async for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + await gen.aclose() + + ns: List[Tuple[psycopg.Notify, float]] = [] + t0 = time.time() + workers = [notifier(), receiver()] + await asyncio.gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +async def canceller(aconn, errors): + try: + await asyncio.sleep(0.5) + aconn.cancel() + except Exception as exc: + errors.append(exc) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +async def test_cancel(aconn): + async def worker(): + cur = aconn.cursor() + with pytest.raises(e.QueryCanceled): + await cur.execute("select pg_sleep(2)") + + errors: List[Exception] = [] + workers = [worker(), canceller(aconn, errors)] + + t0 = time.time() + await asyncio.gather(*workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("cancel") +async def test_cancel_stream(aconn): + async def worker(): + cur = aconn.cursor() + with pytest.raises(e.QueryCanceled): + async for row in cur.stream("select pg_sleep(2)"): + pass + + errors: List[Exception] = [] + workers = [worker(), canceller(aconn, errors)] + + t0 = time.time() + await asyncio.gather(*workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + + +@pytest.mark.slow +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_identify_closure(aconn_cls, dsn): + async def closer(): + await asyncio.sleep(0.2) + await conn2.execute( + "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid] + ) + + aconn = await aconn_cls.connect(dsn) + conn2 = await aconn_cls.connect(dsn) + try: + t = create_task(closer()) + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + await aconn.execute("select pg_sleep(1.0)") + t1 = time.time() + assert 0.2 < t1 - t0 < 0.4 + finally: + await asyncio.gather(t) + finally: + await aconn.close() + await conn2.close() + + +@pytest.mark.slow +@pytest.mark.subprocess +@pytest.mark.skipif( + sys.platform == "win32", reason="don't know how to Ctrl-C on Windows" +) +@pytest.mark.crdb_skip("cancel") +async def test_ctrl_c(dsn): + script = f"""\ +import signal +import asyncio +import psycopg + +async def main(): + ctrl_c = False + loop = asyncio.get_event_loop() + async with await psycopg.AsyncConnection.connect({dsn!r}) as conn: + loop.add_signal_handler(signal.SIGINT, conn.cancel) + cur = conn.cursor() + try: + await cur.execute("select pg_sleep(2)") + except psycopg.errors.QueryCanceled: + ctrl_c = True + + assert ctrl_c, "ctrl-c not received" + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + ), f"transaction status: {{conn.info.transaction_status!r}}" + + await conn.rollback() + assert ( + conn.info.transaction_status == psycopg.pq.TransactionStatus.IDLE + ), f"transaction status: {{conn.info.transaction_status!r}}" + + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + +asyncio.run(main()) +""" + if sys.platform == "win32": + creationflags = sp.CREATE_NEW_PROCESS_GROUP + sig = signal.CTRL_C_EVENT + else: + creationflags = 0 + sig = signal.SIGINT + + proc = sp.Popen([sys.executable, "-s", "-c", script], creationflags=creationflags) + with pytest.raises(sp.TimeoutExpired): + outs, errs = proc.communicate(timeout=1) + + proc.send_signal(sig) + proc.communicate() + assert proc.returncode == 0 diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..57c6c78 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,790 @@ +import time +import pytest +import logging +import weakref +from typing import Any, List +from dataclasses import dataclass + +import psycopg +from psycopg import Notify, errors as e +from psycopg.rows import tuple_row +from psycopg.conninfo import conninfo_to_dict, make_conninfo + +from .utils import gc_collect +from .test_cursor import my_row_factory +from .test_adapt import make_bin_dumper, make_dumper + + +def test_connect(conn_cls, dsn): + conn = conn_cls.connect(dsn) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + conn.close() + + +def test_connect_str_subclass(conn_cls, dsn): + class MyString(str): + pass + + conn = conn_cls.connect(MyString(dsn)) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + conn.close() + + +def test_connect_bad(conn_cls): + with pytest.raises(psycopg.OperationalError): + conn_cls.connect("dbname=nosuchdb") + + +@pytest.mark.slow +@pytest.mark.timing +def test_connect_timeout(conn_cls, deaf_port): + t0 = time.time() + with pytest.raises(psycopg.OperationalError, match="timeout expired"): + conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + elapsed = time.time() - t0 + assert elapsed == pytest.approx(1.0, abs=0.05) + + +def test_close(conn): + assert not conn.closed + assert not conn.broken + + cur = conn.cursor() + + conn.close() + assert conn.closed + assert not conn.broken + assert conn.pgconn.status == conn.ConnStatus.BAD + + conn.close() + assert conn.closed + assert conn.pgconn.status == conn.ConnStatus.BAD + + with pytest.raises(psycopg.OperationalError): + cur.execute("select 1") + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_broken(conn): + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]) + assert conn.closed + assert conn.broken + conn.close() + assert conn.closed + assert conn.broken + + +def test_cursor_closed(conn): + conn.close() + with pytest.raises(psycopg.OperationalError): + with conn.cursor("foo"): + pass + with pytest.raises(psycopg.OperationalError): + conn.cursor() + + +def test_connection_warn_close(conn_cls, dsn, recwarn): + conn = conn_cls.connect(dsn) + conn.close() + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + conn = conn_cls.connect(dsn) + del conn + assert "IDLE" in str(recwarn.pop(ResourceWarning).message) + + conn = conn_cls.connect(dsn) + conn.execute("select 1") + del conn + assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) + + conn = conn_cls.connect(dsn) + try: + conn.execute("select wat") + except Exception: + pass + del conn + assert "INERROR" in str(recwarn.pop(ResourceWarning).message) + + with conn_cls.connect(dsn) as conn: + pass + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.fixture +def testctx(svcconn): + svcconn.execute("create table if not exists testctx (id int primary key)") + svcconn.execute("delete from testctx") + return None + + +def test_context_commit(conn_cls, testctx, conn, dsn): + with conn: + with conn.cursor() as cur: + cur.execute("insert into testctx values (42)") + + assert conn.closed + assert not conn.broken + + with conn_cls.connect(dsn) as conn: + with conn.cursor() as cur: + cur.execute("select * from testctx") + assert cur.fetchall() == [(42,)] + + +def test_context_rollback(conn_cls, testctx, conn, dsn): + with pytest.raises(ZeroDivisionError): + with conn: + with conn.cursor() as cur: + cur.execute("insert into testctx values (42)") + 1 / 0 + + assert conn.closed + assert not conn.broken + + with conn_cls.connect(dsn) as conn: + with conn.cursor() as cur: + cur.execute("select * from testctx") + assert cur.fetchall() == [] + + +def test_context_close(conn): + with conn: + conn.execute("select 1") + conn.close() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_context_inerror_rollback_no_clobber(conn_cls, conn, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn2: + conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn: + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + assert not conn.pgconn.error_message + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.slow +def test_weakref(conn_cls, dsn): + conn = conn_cls.connect(dsn) + w = weakref.ref(conn) + conn.close() + del conn + gc_collect() + assert w() is None + + +def test_commit(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + conn.pgconn.exec_(b"insert into foo values (1)") + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.commit() + + +@pytest.mark.crdb_skip("deferrable") +def test_commit_error(conn): + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + conn.commit() + + conn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + +def test_rollback(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + conn.pgconn.exec_(b"insert into foo values (1)") + conn.rollback() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.ntuples == 0 + + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.rollback() + + +def test_auto_transaction(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +def test_auto_transaction_fail(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + with pytest.raises(psycopg.DatabaseError): + cur.execute("meh") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + with pytest.raises(psycopg.errors.InFailedSqlTransaction): + cur.execute("select 1") + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() is None + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +def test_autocommit(conn): + assert conn.autocommit is False + conn.autocommit = True + assert conn.autocommit + cur = conn.cursor() + assert cur.execute("select 1").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + conn.autocommit = "" + assert conn.autocommit is False # type: ignore[comparison-overlap] + conn.autocommit = "yeah" + assert conn.autocommit is True + + +def test_autocommit_connect(conn_cls, dsn): + conn = conn_cls.connect(dsn, autocommit=True) + assert conn.autocommit + conn.close() + + +def test_autocommit_intrans(conn): + cur = conn.cursor() + assert cur.execute("select 1").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + with pytest.raises(psycopg.ProgrammingError): + conn.autocommit = True + assert not conn.autocommit + + +def test_autocommit_inerror(conn): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.execute("meh") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + with pytest.raises(psycopg.ProgrammingError): + conn.autocommit = True + assert not conn.autocommit + + +def test_autocommit_unknown(conn): + conn.close() + assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN + with pytest.raises(psycopg.OperationalError): + conn.autocommit = True + assert not conn.autocommit + + +@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("host=foo user=bar",), {}, "host=foo user=bar"), + (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + ( + ("host=foo port=5432",), + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + (("host=foo",), {"user": None}, "host=foo"), + ], +) +def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = conn_cls.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), + (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + with pytest.raises(exctype): + conn_cls.connect(*args, **kwargs) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_broken_connection(conn): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.execute("select pg_terminate_backend(pg_backend_pid())") + assert conn.closed + + +@pytest.mark.crdb_skip("do") +def test_notice_handlers(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + messages = [] + severities = [] + + def cb1(diag): + messages.append(diag.message_primary) + + def cb2(res): + raise Exception("hello from cb2") + + conn.add_notice_handler(cb1) + conn.add_notice_handler(cb2) + conn.add_notice_handler("the wrong thing") + conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized)) + + conn.pgconn.exec_(b"set client_min_messages to notice") + cur = conn.cursor() + cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql") + assert messages == ["hello notice"] + assert severities == ["NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + conn.remove_notice_handler(cb1) + conn.remove_notice_handler("the wrong thing") + cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql") + assert len(caplog.records) == 3 + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] + + with pytest.raises(ValueError): + conn.remove_notice_handler(cb1) + + +@pytest.mark.crdb_skip("notify") +def test_notify_handlers(conn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + conn.add_notify_handler(cb1) + conn.add_notify_handler(lambda n: nots2.append(n)) + + conn.autocommit = True + cur = conn.cursor() + cur.execute("listen foo") + cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == conn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + conn.remove_notify_handler(cb1) + cur.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == conn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + conn.remove_notify_handler(cb1) + + +def test_execute(conn): + cur = conn.execute("select %s, %s", [10, 20]) + assert cur.fetchone() == (10, 20) + assert cur.format == 0 + assert cur.pgresult.fformat(0) == 0 + + cur = conn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21}) + assert cur.fetchone() == (11, 21) + + cur = conn.execute("select 12, 22") + assert cur.fetchone() == (12, 22) + + +def test_execute_binary(conn): + cur = conn.execute("select %s, %s", [10, 20], binary=True) + assert cur.fetchone() == (10, 20) + assert cur.format == 1 + assert cur.pgresult.fformat(0) == 1 + + +def test_row_factory(conn_cls, dsn): + defaultconn = conn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row + defaultconn.close() + + conn = conn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory + + cur = conn.execute("select 'a' as ve") + assert cur.fetchone() == ["Ave"] + + with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1: + cur1.execute("select 1, 1, 2") + assert cur1.fetchall() == [{1, 2}] + + with conn.cursor(row_factory=tuple_row) as cur2: + cur2.execute("select 1, 1, 2") + assert cur2.fetchall() == [(1, 1, 2)] + + # TODO: maybe fix something to get rid of 'type: ignore' below. + conn.row_factory = tuple_row + cur3 = conn.execute("select 'vale'") + r = cur3.fetchone() + assert r and r == ("vale",) + conn.close() + + +def test_str(conn): + assert "[IDLE]" in str(conn) + conn.close() + assert "[BAD]" in str(conn) + + +def test_fileno(conn): + assert conn.fileno() == conn.pgconn.socket + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.fileno() + + +def test_cursor_factory(conn): + assert conn.cursor_factory is psycopg.Cursor + + class MyCursor(psycopg.Cursor[psycopg.rows.Row]): + pass + + conn.cursor_factory = MyCursor + with conn.cursor() as cur: + assert isinstance(cur, MyCursor) + + with conn.execute("select 1") as cur: + assert isinstance(cur, MyCursor) + + +def test_cursor_factory_connect(conn_cls, dsn): + class MyCursor(psycopg.Cursor[psycopg.rows.Row]): + pass + + with conn_cls.connect(dsn, cursor_factory=MyCursor) as conn: + assert conn.cursor_factory is MyCursor + cur = conn.cursor() + assert type(cur) is MyCursor + + +def test_server_cursor_factory(conn): + assert conn.server_cursor_factory is psycopg.ServerCursor + + class MyServerCursor(psycopg.ServerCursor[psycopg.rows.Row]): + pass + + conn.server_cursor_factory = MyServerCursor + with conn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor) + + +@dataclass +class ParamDef: + name: str + guc: str + values: List[Any] + + +param_isolation = ParamDef( + name="isolation_level", + guc="isolation", + values=list(psycopg.IsolationLevel), +) +param_read_only = ParamDef( + name="read_only", + guc="read_only", + values=[True, False], +) +param_deferrable = ParamDef( + name="deferrable", + guc="deferrable", + values=[True, False], +) + +# Map Python values to Postgres values for the tx_params possible values +tx_values_map = { + v.name.lower().replace("_", " "): v.value for v in psycopg.IsolationLevel +} +tx_values_map["on"] = True +tx_values_map["off"] = False + + +tx_params = [ + param_isolation, + param_read_only, + pytest.param(param_deferrable, marks=pytest.mark.crdb_skip("deferrable")), +] +tx_params_isolation = [ + pytest.param( + param_isolation, + id="isolation_level", + marks=pytest.mark.crdb("skip", reason="transaction isolation"), + ), + pytest.param( + param_read_only, id="read_only", marks=pytest.mark.crdb_skip("begin_read_only") + ), + pytest.param( + param_deferrable, id="deferrable", marks=pytest.mark.crdb_skip("deferrable") + ), +] + + +@pytest.mark.parametrize("param", tx_params) +def test_transaction_param_default(conn, param): + assert getattr(conn, param.name) is None + current, default = conn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ).fetchone() + assert current == default + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +def test_set_transaction_param_implicit(conn, param, autocommit): + conn.autocommit = autocommit + for value in param.values: + setattr(conn, param.name, value) + pgval, default = conn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ).fetchone() + if autocommit: + assert pgval == default + else: + assert tx_values_map[pgval] == value + conn.rollback() + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +def test_set_transaction_param_block(conn, param, autocommit): + conn.autocommit = autocommit + for value in param.values: + setattr(conn, param.name, value) + with conn.transaction(): + pgval = conn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ).fetchone()[0] + assert tx_values_map[pgval] == value + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_implicit(conn, param): + conn.execute("select 1") + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_block(conn, param): + with conn.transaction(): + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.parametrize("param", tx_params) +def test_set_transaction_param_not_intrans_external(conn, param): + conn.autocommit = True + conn.execute("begin") + with pytest.raises(psycopg.ProgrammingError): + setattr(conn, param.name, param.values[0]) + + +@pytest.mark.crdb("skip", reason="transaction isolation") +def test_set_transaction_param_all(conn): + params: List[Any] = tx_params[:] + params[2] = params[2].values[0] + + for param in params: + value = param.values[0] + setattr(conn, param.name, value) + + for param in params: + pgval = conn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ).fetchone()[0] + assert tx_values_map[pgval] == value + + +def test_set_transaction_param_strange(conn): + for val in ("asdf", 0, 5): + with pytest.raises(ValueError): + conn.isolation_level = val + + conn.isolation_level = psycopg.IsolationLevel.SERIALIZABLE.value + assert conn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE + + conn.read_only = 1 + assert conn.read_only is True + + conn.deferrable = 0 + assert conn.deferrable is False + + +conninfo_params_timeout = [ + ( + "", + {"dbname": "mydb", "connect_timeout": None}, + ({"dbname": "mydb"}, None), + ), + ( + "", + {"dbname": "mydb", "connect_timeout": 1}, + ({"dbname": "mydb", "connect_timeout": "1"}, 1), + ), + ( + "dbname=postgres", + {}, + ({"dbname": "postgres"}, None), + ), + ( + "dbname=postgres connect_timeout=2", + {}, + ({"dbname": "postgres", "connect_timeout": "2"}, 2), + ), + ( + "postgresql:///postgres?connect_timeout=2", + {"connect_timeout": 10}, + ({"dbname": "postgres", "connect_timeout": "10"}, 10), + ), +] + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +def test_get_connection_params(conn_cls, dsn, kwargs, exp): + params = conn_cls._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params.get("connect_timeout") == exp[1] + + +def test_connect_context(conn_cls, dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = conn_cls.connect(dsn, context=ctx) + + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" + conn.close() + + +def test_connect_context_copy(conn_cls, dsn, conn): + conn.adapters.register_dumper(str, make_bin_dumper("b")) + conn.adapters.register_dumper(str, make_dumper("t")) + + conn2 = conn_cls.connect(dsn, context=conn) + + cur = conn2.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn2.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" + conn2.close() + + +def test_cancel_closed(conn): + conn.close() + conn.cancel() diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py new file mode 100644 index 0000000..1288a6c --- /dev/null +++ b/tests/test_connection_async.py @@ -0,0 +1,751 @@ +import time +import pytest +import logging +import weakref +from typing import List, Any + +import psycopg +from psycopg import Notify, errors as e +from psycopg.rows import tuple_row +from psycopg.conninfo import conninfo_to_dict, make_conninfo + +from .utils import gc_collect +from .test_cursor import my_row_factory +from .test_connection import tx_params, tx_params_isolation, tx_values_map +from .test_connection import conninfo_params_timeout +from .test_connection import testctx # noqa: F401 # fixture +from .test_adapt import make_bin_dumper, make_dumper +from .test_conninfo import fake_resolve # noqa: F401 + +pytestmark = pytest.mark.asyncio + + +async def test_connect(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + await conn.close() + + +async def test_connect_bad(aconn_cls): + with pytest.raises(psycopg.OperationalError): + await aconn_cls.connect("dbname=nosuchdb") + + +async def test_connect_str_subclass(aconn_cls, dsn): + class MyString(str): + pass + + conn = await aconn_cls.connect(MyString(dsn)) + assert not conn.closed + assert conn.pgconn.status == conn.ConnStatus.OK + await conn.close() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_connect_timeout(aconn_cls, deaf_port): + t0 = time.time() + with pytest.raises(psycopg.OperationalError, match="timeout expired"): + await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + elapsed = time.time() - t0 + assert elapsed == pytest.approx(1.0, abs=0.05) + + +async def test_close(aconn): + assert not aconn.closed + assert not aconn.broken + + cur = aconn.cursor() + + await aconn.close() + assert aconn.closed + assert not aconn.broken + assert aconn.pgconn.status == aconn.ConnStatus.BAD + + await aconn.close() + assert aconn.closed + assert aconn.pgconn.status == aconn.ConnStatus.BAD + + with pytest.raises(psycopg.OperationalError): + await cur.execute("select 1") + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_broken(aconn): + with pytest.raises(psycopg.OperationalError): + await aconn.execute( + "select pg_terminate_backend(%s)", [aconn.pgconn.backend_pid] + ) + assert aconn.closed + assert aconn.broken + await aconn.close() + assert aconn.closed + assert aconn.broken + + +async def test_cursor_closed(aconn): + await aconn.close() + with pytest.raises(psycopg.OperationalError): + async with aconn.cursor("foo"): + pass + aconn.cursor("foo") + with pytest.raises(psycopg.OperationalError): + aconn.cursor() + + +async def test_connection_warn_close(aconn_cls, dsn, recwarn): + conn = await aconn_cls.connect(dsn) + await conn.close() + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + conn = await aconn_cls.connect(dsn) + del conn + assert "IDLE" in str(recwarn.pop(ResourceWarning).message) + + conn = await aconn_cls.connect(dsn) + await conn.execute("select 1") + del conn + assert "INTRANS" in str(recwarn.pop(ResourceWarning).message) + + conn = await aconn_cls.connect(dsn) + try: + await conn.execute("select wat") + except Exception: + pass + del conn + assert "INERROR" in str(recwarn.pop(ResourceWarning).message) + + async with await aconn_cls.connect(dsn) as conn: + pass + del conn + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.usefixtures("testctx") +async def test_context_commit(aconn_cls, aconn, dsn): + async with aconn: + async with aconn.cursor() as cur: + await cur.execute("insert into testctx values (42)") + + assert aconn.closed + assert not aconn.broken + + async with await aconn_cls.connect(dsn) as aconn: + async with aconn.cursor() as cur: + await cur.execute("select * from testctx") + assert await cur.fetchall() == [(42,)] + + +@pytest.mark.usefixtures("testctx") +async def test_context_rollback(aconn_cls, aconn, dsn): + with pytest.raises(ZeroDivisionError): + async with aconn: + async with aconn.cursor() as cur: + await cur.execute("insert into testctx values (42)") + 1 / 0 + + assert aconn.closed + assert not aconn.broken + + async with await aconn_cls.connect(dsn) as aconn: + async with aconn.cursor() as cur: + await cur.execute("select * from testctx") + assert await cur.fetchall() == [] + + +async def test_context_close(aconn): + async with aconn: + await aconn.execute("select 1") + await aconn.close() + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_context_inerror_rollback_no_clobber(aconn_cls, conn, dsn, caplog): + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn2: + await conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn: + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + assert not conn.pgconn.error_message + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.slow +async def test_weakref(aconn_cls, dsn): + conn = await aconn_cls.connect(dsn) + w = weakref.ref(conn) + await conn.close() + del conn + gc_collect() + assert w() is None + + +async def test_commit(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + aconn.pgconn.exec_(b"insert into foo values (1)") + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + await aconn.close() + with pytest.raises(psycopg.OperationalError): + await aconn.commit() + + +@pytest.mark.crdb_skip("deferrable") +async def test_commit_error(aconn): + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + await aconn.commit() + + await aconn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + cur = await aconn.execute("select 1") + assert await cur.fetchone() == (1,) + + +async def test_rollback(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + aconn.pgconn.exec_(b"insert into foo values (1)") + await aconn.rollback() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.ntuples == 0 + + await aconn.close() + with pytest.raises(psycopg.OperationalError): + await aconn.rollback() + + +async def test_auto_transaction(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await cur.execute("insert into foo values (1)") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + await cur.execute("select * from foo") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_auto_transaction_fail(aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await cur.execute("insert into foo values (1)") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + with pytest.raises(psycopg.DatabaseError): + await cur.execute("meh") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + await aconn.commit() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + await cur.execute("select * from foo") + assert await cur.fetchone() is None + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_autocommit(aconn): + assert aconn.autocommit is False + with pytest.raises(AttributeError): + aconn.autocommit = True + assert not aconn.autocommit + + await aconn.set_autocommit(True) + assert aconn.autocommit + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + await aconn.set_autocommit("") + assert aconn.autocommit is False + await aconn.set_autocommit("yeah") + assert aconn.autocommit is True + + +async def test_autocommit_connect(aconn_cls, dsn): + aconn = await aconn_cls.connect(dsn, autocommit=True) + assert aconn.autocommit + await aconn.close() + + +async def test_autocommit_intrans(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + with pytest.raises(psycopg.ProgrammingError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + +async def test_autocommit_inerror(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("meh") + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + with pytest.raises(psycopg.ProgrammingError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + +async def test_autocommit_unknown(aconn): + await aconn.close() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN + with pytest.raises(psycopg.OperationalError): + await aconn.set_autocommit(True) + assert not aconn.autocommit + + +@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), + (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), + ( + ("dbname=foo port=5432",), + {"dbname": "qux", "user": "joe"}, + "dbname=qux user=joe port=5432", + ), + (("dbname=foo",), {"user": None}, "dbname=foo"), + ], +) +async def test_connect_args( + aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want +): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + setpgenv({}) + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = await aconn_cls.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + await conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), + (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + with pytest.raises(exctype): + await aconn_cls.connect(*args, **kwargs) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_broken_connection(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("select pg_terminate_backend(pg_backend_pid())") + assert aconn.closed + + +@pytest.mark.crdb_skip("do") +async def test_notice_handlers(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + messages = [] + severities = [] + + def cb1(diag): + messages.append(diag.message_primary) + + def cb2(res): + raise Exception("hello from cb2") + + aconn.add_notice_handler(cb1) + aconn.add_notice_handler(cb2) + aconn.add_notice_handler("the wrong thing") + aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized)) + + aconn.pgconn.exec_(b"set client_min_messages to notice") + cur = aconn.cursor() + await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql") + assert messages == ["hello notice"] + assert severities == ["NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + aconn.remove_notice_handler(cb1) + aconn.remove_notice_handler("the wrong thing") + await cur.execute( + "do $$begin raise warning 'hello warning'; end$$ language plpgsql" + ) + assert len(caplog.records) == 3 + assert messages == ["hello notice"] + assert severities == ["NOTICE", "WARNING"] + + with pytest.raises(ValueError): + aconn.remove_notice_handler(cb1) + + +@pytest.mark.crdb_skip("notify") +async def test_notify_handlers(aconn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + aconn.add_notify_handler(cb1) + aconn.add_notify_handler(lambda n: nots2.append(n)) + + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + await cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == aconn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + aconn.remove_notify_handler(cb1) + await cur.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == aconn.pgconn.backend_pid + + with pytest.raises(ValueError): + aconn.remove_notify_handler(cb1) + + +async def test_execute(aconn): + cur = await aconn.execute("select %s, %s", [10, 20]) + assert await cur.fetchone() == (10, 20) + assert cur.format == 0 + assert cur.pgresult.fformat(0) == 0 + + cur = await aconn.execute("select %(a)s, %(b)s", {"a": 11, "b": 21}) + assert await cur.fetchone() == (11, 21) + + cur = await aconn.execute("select 12, 22") + assert await cur.fetchone() == (12, 22) + + +async def test_execute_binary(aconn): + cur = await aconn.execute("select %s, %s", [10, 20], binary=True) + assert await cur.fetchone() == (10, 20) + assert cur.format == 1 + assert cur.pgresult.fformat(0) == 1 + + +async def test_row_factory(aconn_cls, dsn): + defaultconn = await aconn_cls.connect(dsn) + assert defaultconn.row_factory is tuple_row + await defaultconn.close() + + conn = await aconn_cls.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory is my_row_factory + + cur = await conn.execute("select 'a' as ve") + assert await cur.fetchone() == ["Ave"] + + async with conn.cursor(row_factory=lambda c: lambda t: set(t)) as cur1: + await cur1.execute("select 1, 1, 2") + assert await cur1.fetchall() == [{1, 2}] + + async with conn.cursor(row_factory=tuple_row) as cur2: + await cur2.execute("select 1, 1, 2") + assert await cur2.fetchall() == [(1, 1, 2)] + + # TODO: maybe fix something to get rid of 'type: ignore' below. + conn.row_factory = tuple_row + cur3 = await conn.execute("select 'vale'") + r = await cur3.fetchone() + assert r and r == ("vale",) + await conn.close() + + +async def test_str(aconn): + assert "[IDLE]" in str(aconn) + await aconn.close() + assert "[BAD]" in str(aconn) + + +async def test_fileno(aconn): + assert aconn.fileno() == aconn.pgconn.socket + await aconn.close() + with pytest.raises(psycopg.OperationalError): + aconn.fileno() + + +async def test_cursor_factory(aconn): + assert aconn.cursor_factory is psycopg.AsyncCursor + + class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]): + pass + + aconn.cursor_factory = MyCursor + async with aconn.cursor() as cur: + assert isinstance(cur, MyCursor) + + async with (await aconn.execute("select 1")) as cur: + assert isinstance(cur, MyCursor) + + +async def test_cursor_factory_connect(aconn_cls, dsn): + class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]): + pass + + async with await aconn_cls.connect(dsn, cursor_factory=MyCursor) as conn: + assert conn.cursor_factory is MyCursor + cur = conn.cursor() + assert type(cur) is MyCursor + + +async def test_server_cursor_factory(aconn): + assert aconn.server_cursor_factory is psycopg.AsyncServerCursor + + class MyServerCursor(psycopg.AsyncServerCursor[psycopg.rows.Row]): + pass + + aconn.server_cursor_factory = MyServerCursor + async with aconn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor) + + +@pytest.mark.parametrize("param", tx_params) +async def test_transaction_param_default(aconn, param): + assert getattr(aconn, param.name) is None + cur = await aconn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ) + current, default = await cur.fetchone() + assert current == default + + +@pytest.mark.parametrize("param", tx_params) +async def test_transaction_param_readonly_property(aconn, param): + with pytest.raises(AttributeError): + setattr(aconn, param.name, None) + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +async def test_set_transaction_param_implicit(aconn, param, autocommit): + await aconn.set_autocommit(autocommit) + for value in param.values: + await getattr(aconn, f"set_{param.name}")(value) + cur = await aconn.execute( + "select current_setting(%s), current_setting(%s)", + [f"transaction_{param.guc}", f"default_transaction_{param.guc}"], + ) + pgval, default = await cur.fetchone() + if autocommit: + assert pgval == default + else: + assert tx_values_map[pgval] == value + await aconn.rollback() + + +@pytest.mark.parametrize("autocommit", [True, False]) +@pytest.mark.parametrize("param", tx_params_isolation) +async def test_set_transaction_param_block(aconn, param, autocommit): + await aconn.set_autocommit(autocommit) + for value in param.values: + await getattr(aconn, f"set_{param.name}")(value) + async with aconn.transaction(): + cur = await aconn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ) + pgval = (await cur.fetchone())[0] + assert tx_values_map[pgval] == value + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_implicit(aconn, param): + await aconn.execute("select 1") + value = param.values[0] + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_block(aconn, param): + value = param.values[0] + async with aconn.transaction(): + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.parametrize("param", tx_params) +async def test_set_transaction_param_not_intrans_external(aconn, param): + value = param.values[0] + await aconn.set_autocommit(True) + await aconn.execute("begin") + with pytest.raises(psycopg.ProgrammingError): + await getattr(aconn, f"set_{param.name}")(value) + + +@pytest.mark.crdb("skip", reason="transaction isolation") +async def test_set_transaction_param_all(aconn): + params: List[Any] = tx_params[:] + params[2] = params[2].values[0] + + for param in params: + value = param.values[0] + await getattr(aconn, f"set_{param.name}")(value) + + for param in params: + cur = await aconn.execute( + "select current_setting(%s)", [f"transaction_{param.guc}"] + ) + pgval = (await cur.fetchone())[0] + assert tx_values_map[pgval] == value + + +async def test_set_transaction_param_strange(aconn): + for val in ("asdf", 0, 5): + with pytest.raises(ValueError): + await aconn.set_isolation_level(val) + + await aconn.set_isolation_level(psycopg.IsolationLevel.SERIALIZABLE.value) + assert aconn.isolation_level is psycopg.IsolationLevel.SERIALIZABLE + + await aconn.set_read_only(1) + assert aconn.read_only is True + + await aconn.set_deferrable(0) + assert aconn.deferrable is False + + +@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) +async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv): + setpgenv({}) + params = await aconn_cls._get_connection_params(dsn, **kwargs) + conninfo = make_conninfo(**params) + assert conninfo_to_dict(conninfo) == exp[0] + assert params["connect_timeout"] == exp[1] + + +async def test_connect_context_adapters(aconn_cls, dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = await aconn_cls.connect(dsn, context=ctx) + + cur = await conn.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await conn.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob" + await conn.close() + + +async def test_connect_context_copy(aconn_cls, dsn, aconn): + aconn.adapters.register_dumper(str, make_bin_dumper("b")) + aconn.adapters.register_dumper(str, make_dumper("t")) + + aconn2 = await aconn_cls.connect(dsn, context=aconn) + + cur = await aconn2.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await aconn2.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob" + await aconn2.close() + + +async def test_cancel_closed(aconn): + await aconn.close() + aconn.cancel() + + +async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811 + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + await psycopg.AsyncConnection.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert conninfo_to_dict(got[0]) == want diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py new file mode 100644 index 0000000..e2c2c01 --- /dev/null +++ b/tests/test_conninfo.py @@ -0,0 +1,450 @@ +import socket +import asyncio +import datetime as dt + +import pytest + +import psycopg +from psycopg import ProgrammingError +from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from psycopg.conninfo import resolve_hostaddr_async +from psycopg._encodings import pg2pyenc + +from .fix_crdb import crdb_encoding + +snowman = "\u2603" + + +class MyString(str): + pass + + +@pytest.mark.parametrize( + "conninfo, kwargs, exp", + [ + ("", {}, ""), + ("dbname=foo", {}, "dbname=foo"), + ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"), + ("dbname=sony", {"password": ""}, "dbname=sony password="), + ("dbname=foo", {"dbname": "bar"}, "dbname=bar"), + ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"), + ("", {"dbname": "foo"}, "dbname=foo"), + ("", {"dbname": "foo", "user": None}, "dbname=foo"), + ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"), + ("", {"dbname": "a'b"}, r"dbname='a\'b'"), + (f"dbname={snowman}", {}, f"dbname={snowman}"), + ("", {"dbname": snowman}, f"dbname={snowman}"), + ( + "postgresql://host1/test", + {"host": "host2"}, + "dbname=test host=host2", + ), + (MyString(""), {}, ""), + ], +) +def test_make_conninfo(conninfo, kwargs, exp): + out = make_conninfo(conninfo, **kwargs) + assert conninfo_to_dict(out) == conninfo_to_dict(exp) + + +@pytest.mark.parametrize( + "conninfo, kwargs", + [ + ("hello", {}), + ("dbname=foo bar", {}), + ("foo=bar", {}), + ("dbname=foo", {"bar": "baz"}), + ("postgresql://tester:secret@/test?port=5433=x", {}), + (f"{snowman}={snowman}", {}), + ], +) +def test_make_conninfo_bad(conninfo, kwargs): + with pytest.raises(ProgrammingError): + make_conninfo(conninfo, **kwargs) + + +@pytest.mark.parametrize( + "conninfo, exp", + [ + ("", {}), + ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}), + ("dbname=sony password=", {"dbname": "sony", "password": ""}), + ("dbname='foo bar'", {"dbname": "foo bar"}), + ("dbname='a\"b'", {"dbname": 'a"b'}), + (r"dbname='a\'b'", {"dbname": "a'b"}), + (r"dbname='a\\b'", {"dbname": r"a\b"}), + (f"dbname={snowman}", {"dbname": snowman}), + ( + "postgresql://tester:secret@/test?port=5433", + { + "user": "tester", + "password": "secret", + "dbname": "test", + "port": "5433", + }, + ), + ], +) +def test_conninfo_to_dict(conninfo, exp): + assert conninfo_to_dict(conninfo) == exp + + +def test_no_munging(): + dsnin = "dbname=a host=b user=c password=d" + dsnout = make_conninfo(dsnin) + assert dsnin == dsnout + + +class TestConnectionInfo: + @pytest.mark.parametrize( + "attr", + [("dbname", "db"), "host", "hostaddr", "user", "password", "options"], + ) + def test_attrs(self, conn, attr): + if isinstance(attr, tuple): + info_attr, pgconn_attr = attr + else: + info_attr = pgconn_attr = attr + + if info_attr == "hostaddr" and psycopg.pq.version() < 120000: + pytest.skip("hostaddr not supported on libpq < 12") + + info_val = getattr(conn.info, info_attr) + pgconn_val = getattr(conn.pgconn, pgconn_attr).decode() + assert info_val == pgconn_val + + conn.close() + with pytest.raises(psycopg.OperationalError): + getattr(conn.info, info_attr) + + @pytest.mark.libpq("< 12") + def test_hostaddr_not_supported(self, conn): + with pytest.raises(psycopg.NotSupportedError): + conn.info.hostaddr + + def test_port(self, conn): + assert conn.info.port == int(conn.pgconn.port.decode()) + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.port + + def test_get_params(self, conn, dsn): + info = conn.info.get_parameters() + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert info.get(k) == v + else: + assert k not in info + + def test_dsn(self, conn, dsn): + dsn = conn.info.dsn + assert "password" not in dsn + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert f"{k}=" in dsn + + def test_get_params_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name" not in conn.info.get_parameters() + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert conn.info.get_parameters()["application_name"] == "hello test" + + def test_dsn_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name=" not in conn.info.dsn + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert "application_name='hello test'" in conn.info.dsn + + def test_status(self, conn): + assert conn.info.status.name == "OK" + conn.close() + assert conn.info.status.name == "BAD" + + def test_transaction_status(self, conn): + assert conn.info.transaction_status.name == "IDLE" + conn.close() + assert conn.info.transaction_status.name == "UNKNOWN" + + @pytest.mark.pipeline + def test_pipeline_status(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + with conn.pipeline(): + assert conn.info.pipeline_status + assert conn.info.pipeline_status.name == "ON" + + @pytest.mark.libpq("< 14") + def test_pipeline_status_no_pipeline(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + + def test_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in info.get_parameters() + assert info.get_parameters()["dbname"] == info.dbname + + def test_dsn_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in info.dsn + assert f"dbname={info.dbname}" in info.dsn + + def test_parameter_status(self, conn): + assert conn.info.parameter_status("nosuchparam") is None + tz = conn.info.parameter_status("TimeZone") + assert tz and isinstance(tz, str) + assert tz == conn.execute("show timezone").fetchone()[0] + + @pytest.mark.crdb("skip") + def test_server_version(self, conn): + assert conn.info.server_version == conn.pgconn.server_version + + def test_error_message(self, conn): + assert conn.info.error_message == "" + with pytest.raises(psycopg.ProgrammingError) as ex: + conn.execute("wat") + + assert conn.info.error_message + assert str(ex.value) in conn.info.error_message + assert ex.value.diag.severity in conn.info.error_message + + conn.close() + assert "NULL" in conn.info.error_message + + @pytest.mark.crdb_skip("backend pid") + def test_backend_pid(self, conn): + assert conn.info.backend_pid + assert conn.info.backend_pid == conn.pgconn.backend_pid + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.backend_pid + + def test_timezone(self, conn): + conn.execute("set timezone to 'Europe/Rome'") + tz = conn.info.timezone + assert isinstance(tz, dt.tzinfo) + offset = tz.utcoffset(dt.datetime(2000, 1, 1)) + assert offset and offset.total_seconds() == 3600 + offset = tz.utcoffset(dt.datetime(2000, 7, 1)) + assert offset and offset.total_seconds() == 7200 + + @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones") + def test_timezone_warn(self, conn, caplog): + conn.execute("set timezone to 'FOOBAR0'") + assert len(caplog.records) == 0 + tz = conn.info.timezone + assert tz == dt.timezone.utc + assert len(caplog.records) == 1 + assert "FOOBAR0" in caplog.records[0].message + + conn.info.timezone + assert len(caplog.records) == 1 + + conn.execute("set timezone to 'FOOBAAR0'") + assert len(caplog.records) == 1 + conn.info.timezone + assert len(caplog.records) == 2 + assert "FOOBAAR0" in caplog.records[1].message + + def test_encoding(self, conn): + enc = conn.execute("show client_encoding").fetchone()[0] + assert conn.info.encoding == pg2pyenc(enc.encode()) + + @pytest.mark.crdb("skip", reason="encoding not normalized") + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + ("eucjp", "EUC_JP", "euc_jp"), + ("euc-jp", "EUC_JP", "euc_jp"), + ("latin9", "LATIN9", "iso8859-15"), + ], + ) + def test_normalize_encoding(self, conn, enc, out, codec): + conn.execute("select set_config('client_encoding', %s, false)", [enc]) + assert conn.info.parameter_status("client_encoding") == out + assert conn.info.encoding == codec + + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + crdb_encoding("eucjp", "EUC_JP", "euc_jp"), + crdb_encoding("euc-jp", "EUC_JP", "euc_jp"), + ], + ) + def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec): + monkeypatch.setenv("PGCLIENTENCODING", enc) + with conn_cls.connect(dsn) as conn: + clienc = conn.info.parameter_status("client_encoding") + assert clienc + if conn.info.vendor == "PostgreSQL": + assert clienc == out + else: + assert clienc.replace("-", "").replace("_", "").upper() == out + assert conn.info.encoding == codec + + @pytest.mark.crdb_skip("encoding") + def test_set_encoding_unsupported(self, conn): + cur = conn.cursor() + cur.execute("set client_encoding to EUC_TW") + with pytest.raises(psycopg.NotSupportedError): + cur.execute("select 'x'") + + def test_vendor(self, conn): + assert conn.info.vendor + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", "", None), + ("host='' user=bar", "host='' user=bar", None), + ( + "host=127.0.0.1 user=bar", + "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + "host=foo.com port=5432", + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_no_resolve( + setpgenv, conninfo, want, env, fail_resolve +): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=foo.com,qux.com port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + None, + ), + ( + "host=foo.com,nosuchhost.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com, port=5432,5433", + "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + None, + ), + ( + "host=nosuchhost.com,foo.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + {}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await resolve_hostaddr_async(params) + + +@pytest.fixture +async def fake_resolve(monkeypatch): + fake_hosts = { + "localhost": "127.0.0.1", + "foo.com": "1.1.1.1", + "qux.com": "2.2.2.2", + } + + async def fake_getaddrinfo(host, port, **kwargs): + assert isinstance(port, int) or (isinstance(port, str) and port.isdigit()) + try: + addr = fake_hosts[host] + except KeyError: + raise OSError(f"unknown test host: {host}") + else: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))] + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + +@pytest.fixture +async def fail_resolve(monkeypatch): + async def fail_getaddrinfo(host, port, **kwargs): + pytest.fail(f"shouldn't try to resolve {host}") + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo) diff --git a/tests/test_copy.py b/tests/test_copy.py new file mode 100644 index 0000000..17cf2fc --- /dev/null +++ b/tests/test_copy.py @@ -0,0 +1,889 @@ +import string +import struct +import hashlib +from io import BytesIO, StringIO +from random import choice, randrange +from itertools import cycle + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter +from psycopg.adapt import PyFormat +from psycopg.types import TypeInfo +from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 + +from .utils import eur, gc_collect, gc_count + +pytestmark = pytest.mark.crdb_skip("copy") + +sample_records = [(40010, 40020, "hello"), (40040, None, "world")] +sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')" +sample_tabledef = "col1 serial primary key, col2 int, data text" + +sample_text = b"""\ +40010\t40020\thello +40040\t\\N\tworld +""" + +sample_binary_str = """ +5047 434f 5059 0aff 0d0a 00 +00 0000 0000 0000 00 +00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f + +0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64 + +ff ff +""" + +sample_binary_rows = [ + bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n") +] +sample_binary = b"".join(sample_binary_rows) + +special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"} + + +@pytest.mark.parametrize("format", Format) +def test_copy_out_read(conn, format): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + for row in want: + got = copy.read() + assert got == row + assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE + + assert copy.read() == b"" + assert copy.read() == b"" + + assert copy.read() == b"" + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_copy_out_iter(conn, format, row_factory): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + rf = getattr(psycopg.rows, row_factory) + cur = conn.cursor(row_factory=rf) + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + assert list(copy) == want + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_copy_out_no_result(conn, format, row_factory): + rf = getattr(psycopg.rows, row_factory) + cur = conn.cursor(row_factory=rf) + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"): + with pytest.raises(e.ProgrammingError): + cur.fetchone() + + +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +def test_copy_out_param(conn, ph, params): + cur = conn.cursor() + with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert list(copy.rows()) == [(i + 1,) for i in range(10)] + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("typetype", ["names", "oids"]) +def test_read_rows(conn, format, typetype): + cur = conn.cursor() + with cur.copy( + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" + ) as copy: + copy.set_types(["int4", "text", "float8[]"]) + row = copy.read_row() + assert copy.read_row() is None + + assert row == (10, "hello", [0.0, 1.0]) + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +def test_rows(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + rows = list(copy.rows()) + + assert rows == sample_records + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_set_custom_type(conn, hstore): + command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout""" + cur = conn.cursor() + + with cur.copy(command) as copy: + rows = list(copy.rows()) + + assert rows == [('"a"=>"1", "b"=>"2"',)] + + register_hstore(TypeInfo.fetch(conn, "hstore"), cur) + with cur.copy(command) as copy: + copy.set_types(["hstore"]) + rows = list(copy.rows()) + + assert rows == [({"a": "1", "b": "2"},)] + + +@pytest.mark.parametrize("format", Format) +def test_copy_out_allchars(conn, format): + cur = conn.cursor() + chars = list(map(chr, range(1, 256))) + [eur] + conn.execute("set client_encoding to utf8") + rows = [] + query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format( + chars, sql.SQL(format.name) + ) + with cur.copy(query) as copy: + copy.set_types(["text"]) + while True: + row = copy.read_row() + if not row: + break + assert len(row) == 1 + rows.append(row[0]) + + assert rows == chars + + +@pytest.mark.parametrize("format", Format) +def test_read_row_notypes(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + rows = [] + while True: + row = copy.read_row() + if not row: + break + rows.append(row) + + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("format", Format) +def test_rows_notypes(conn, format): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + rows = list(copy.rows()) + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("err", [-1, 1]) +@pytest.mark.parametrize("format", Format) +def test_copy_out_badntypes(conn, format, err): + cur = conn.cursor() + with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy: + copy.set_types([0] * (len(sample_records[0]) + err)) + with pytest.raises(e.ProgrammingError): + copy.read_row() + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_bad_result(conn): + conn.autocommit = True + + cur = conn.cursor() + + with pytest.raises(e.SyntaxError): + with cur.copy("wat"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("select 1"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("reset timezone"): + pass + + with pytest.raises(e.ProgrammingError): + with cur.copy("copy (select 1) to stdout; select 1") as copy: + list(copy) + + with pytest.raises(e.ProgrammingError): + with cur.copy("select 1; copy (select 1) to stdout"): + pass + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin (format binary)") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +def test_copy_big_size_block(conn, pytype): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.parametrize("format", Format) +def test_subclass_adapter(conn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import ( # type: ignore[no-redef] + StrBinaryDumper as BaseDumper, + ) + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + conn.adapters.register_dumper(str, MyStrDumper) + + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy: + copy.write_row(("hello",)) + + rec = cur.execute("select data from copy_in").fetchone() + assert rec[0] == "hellohello" + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_error_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin (format text)") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_out_error_with_copy_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_copy_out_error_with_copy_not_finished(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy: + copy.read_row() + 1 / 0 + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_out_server_error(conn): + cur = conn.cursor() + with pytest.raises(e.DivisionByZero): + with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + for block in copy: + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int, data text") + + with cur.copy( + f"copy copy_in (col2, data) from stdin (format {format.name})" + ) as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == [(1, None, "hello"), (2, None, "world")] + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + conn.execute("set client_encoding to utf8") + with cur.copy("copy copy_in from stdin (format text)") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +def test_copy_in_format(conn): + file = BytesIO() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with Copy(cur, writer=FileWriter(file)) as copy: + for i in range(1, 256): + copy.write_row((i, chr(i))) + + file.seek(0) + rows = file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_file_writer(conn, format, buffer): + file = BytesIO() + conn.execute("set client_encoding to utf8") + cur = conn.cursor() + with Copy(cur, binary=format, writer=FileWriter(file)) as copy: + for record in sample_records: + copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + +@pytest.mark.slow +def test_copy_from_to(conn): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + + gen.assert_data() + + f = BytesIO() + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +def test_copy_from_to_bytes(conn, pytype): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(pytype(block.encode())) + + gen.assert_data() + + f = BytesIO() + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +def test_copy_from_insane_size(conn): + # Trying to trigger a "would block" error + gen = DataGenerator( + conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024 + ) + gen.ensure_table() + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + + gen.assert_data() + + +def test_copy_rowcount(conn): + gen = DataGenerator(conn, nrecs=3, srec=10) + gen.ensure_table() + + cur = conn.cursor() + with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + copy.write(block) + assert cur.rowcount == 3 + + gen = DataGenerator(conn, nrecs=2, srec=10, offset=3) + with cur.copy("copy copy_in from stdin") as copy: + for rec in gen.records(): + copy.write_row(rec) + assert cur.rowcount == 2 + + with cur.copy("copy copy_in to stdout") as copy: + for block in copy: + pass + assert cur.rowcount == 5 + + with pytest.raises(e.BadCopyFileFormat): + with cur.copy("copy copy_in (id) from stdin") as copy: + for rec in gen.records(): + copy.write_row(rec) + assert cur.rowcount == -1 + + +def test_copy_query(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + assert cur._query.query == b"copy (select 1) to stdout" + assert not cur._query.params + list(copy) + + +def test_cant_reenter(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + list(copy) + + with pytest.raises(TypeError): + with copy: + list(copy) + + +def test_str(conn): + cur = conn.cursor() + with cur.copy("copy (select 1) to stdout") as copy: + assert "[ACTIVE]" in str(copy) + list(copy) + + assert "[INTRANS]" in str(copy) + + +def test_description(conn): + with conn.cursor() as cur: + with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy: + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + list(copy.rows()) + + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_worker_life(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=QueuedLibpqDriver(cur) + ) as copy: + assert not copy.writer._worker + copy.write(globals()[buffer]) + assert copy.writer._worker + + assert not copy.writer._worker + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_worker_error_propagated(conn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = conn.cursor() + cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + with cur.copy("copy wat from stdin", writer=QueuedLibpqDriver(cur)) as copy: + copy.write("a,b") + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +def test_connection_writer(conn, format, buffer): + cur = conn.cursor() + writer = LibpqWriter(cur) + + ensure_table(cur, sample_tabledef) + with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) +def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + stmt = sql.SQL( + "copy (select {} from {} order by id) to stdout (format {})" + ).format( + sql.SQL(", ").join(faker.fields_names), + faker.table_name, + sql.SQL(fmt.name), + ) + + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + + if method == "read": + while True: + tmp = copy.read() + if not tmp: + break + elif method == "iter": + list(copy) + elif method == "row": + while True: + tmp = copy.read_row() + if tmp is None: + break + elif method == "rows": + list(copy.rows()) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin (format {})").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL(fmt.name), + ) + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + copy.write_row(row) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["row", "block", "binary"]) +def test_copy_table_across(conn_cls, dsn, faker, mode): + faker.choose_schema(ncols=20) + faker.make_records(20) + + with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2: + faker.table_name = sql.Identifier("copy_src") + conn1.execute(faker.drop_stmt) + conn1.execute(faker.create_stmt) + conn1.cursor().executemany(faker.insert_stmt, faker.records) + + faker.table_name = sql.Identifier("copy_tgt") + conn2.execute(faker.drop_stmt) + conn2.execute(faker.create_stmt) + + fmt = "(format binary)" if mode == "binary" else "" + with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1: + with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2: + if mode == "row": + for row in copy1.rows(): + copy2.write_row(row) + else: + for data in copy1: + copy2.write(data) + + recs = conn2.execute(faker.select_stmt).fetchall() + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +def py_to_raw(item, fmt): + """Convert from Python type to the expected result from the db""" + if fmt == Format.TEXT: + if isinstance(item, int): + return str(item) + else: + if isinstance(item, int): + # Assume int4 + return struct.pack("!i", item) + elif isinstance(item, str): + return item.encode() + return item + + +def ensure_table(cur, tabledef, name="copy_in"): + cur.execute(f"drop table if exists {name}") + cur.execute(f"create table {name} ({tabledef})") + + +class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): + self.conn = conn + self.nrecs = nrecs + self.srec = srec + self.offset = offset + self.block_size = block_size + + def ensure_table(self): + cur = self.conn.cursor() + ensure_table(cur, "id integer primary key, data text") + + def records(self): + for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): + s = c * self.srec + yield (i + self.offset, s) + + def file(self): + f = StringIO() + for i, s in self.records(): + f.write("%s\t%s\n" % (i, s)) + + f.seek(0) + return f + + def blocks(self): + f = self.file() + while True: + block = f.read(self.block_size) + if not block: + break + yield block + + def assert_data(self): + cur = self.conn.cursor() + cur.execute("select id, data from copy_in order by id") + for record in self.records(): + assert record == cur.fetchone() + + assert cur.fetchone() is None + + def sha(self, f): + m = hashlib.sha256() + while True: + block = f.read() + if not block: + break + if isinstance(block, str): + block = block.encode() + m.update(block) + return m.hexdigest() diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py new file mode 100644 index 0000000..59389dd --- /dev/null +++ b/tests/test_copy_async.py @@ -0,0 +1,892 @@ +import string +import hashlib +from io import BytesIO, StringIO +from random import choice, randrange +from itertools import cycle + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.copy import AsyncCopy +from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter +from psycopg.types import TypeInfo +from psycopg.adapt import PyFormat +from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 + +from .utils import alist, eur, gc_collect, gc_count +from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa +from .test_copy import sample_values, sample_records, sample_tabledef +from .test_copy import py_to_raw, special_chars + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("copy"), +] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_read(aconn, format): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + for row in want: + got = await copy.read() + assert got == row + assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE + + assert await copy.read() == b"" + assert await copy.read() == b"" + + assert await copy.read() == b"" + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_iter(aconn, format, row_factory): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + assert await alist(copy) == want + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_no_result(aconn, format, row_factory): + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"): + with pytest.raises(e.ProgrammingError): + await cur.fetchone() + + +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("typetype", ["names", "oids"]) +async def test_read_rows(aconn, format, typetype): + cur = aconn.cursor() + async with cur.copy( + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" + ) as copy: + copy.set_types(["int4", "text", "float8[]"]) + row = await copy.read_row() + assert (await copy.read_row()) is None + + assert row == (10, "hello", [0.0, 1.0]) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +async def test_rows(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types("int4 int4 text".split()) + rows = await alist(copy.rows()) + + assert rows == sample_records + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_set_custom_type(aconn, hstore): + command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout""" + cur = aconn.cursor() + + async with cur.copy(command) as copy: + rows = await alist(copy.rows()) + + assert rows == [('"a"=>"1", "b"=>"2"',)] + + register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur) + async with cur.copy(command) as copy: + copy.set_types(["hstore"]) + rows = await alist(copy.rows()) + + assert rows == [({"a": "1", "b": "2"},)] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_allchars(aconn, format): + cur = aconn.cursor() + chars = list(map(chr, range(1, 256))) + [eur] + await aconn.execute("set client_encoding to utf8") + rows = [] + query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format( + chars, sql.SQL(format.name) + ) + async with cur.copy(query) as copy: + copy.set_types(["text"]) + while True: + row = await copy.read_row() + if not row: + break + assert len(row) == 1 + rows.append(row[0]) + + assert rows == chars + + +@pytest.mark.parametrize("format", Format) +async def test_read_row_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = [] + while True: + row = await copy.read_row() + if not row: + break + rows.append(row) + + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("format", Format) +async def test_rows_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = await alist(copy.rows()) + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("err", [-1, 1]) +@pytest.mark.parametrize("format", Format) +async def test_copy_out_badntypes(aconn, format, err): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types([0] * (len(sample_records[0]) + err)) + with pytest.raises(e.ProgrammingError): + await copy.read_row() + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_bad_result(aconn): + await aconn.set_autocommit(True) + + cur = aconn.cursor() + + with pytest.raises(e.SyntaxError): + async with cur.copy("wat"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("reset timezone"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("copy (select 1) to stdout; select 1") as copy: + await alist(copy) + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1; copy (select 1) to stdout"): + pass + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin (format binary)") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +async def test_copy_big_size_block(aconn, pytype): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.parametrize("format", Format) +async def test_subclass_adapter(aconn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import ( # type: ignore[no-redef] + StrBinaryDumper as BaseDumper, + ) + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + aconn.adapters.register_dumper(str, MyStrDumper) + + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy( + f"copy copy_in (data) from stdin (format {format.name})" + ) as copy: + await copy.write_row(("hello",)) + + await cur.execute("select data from copy_in") + rec = await cur.fetchone() + assert rec[0] == "hellohello" + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_error_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_error_with_copy_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_copy_out_error_with_copy_not_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy (select generate_series(1, 1000000)) to stdout" + ) as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_server_error(aconn): + cur = aconn.cursor() + with pytest.raises(e.DivisionByZero): + async with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + async for block in copy: + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int, data text") + + async with cur.copy( + f"copy copy_in (col2, data) from stdin (format {format.name})" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == [(1, None, "hello"), (2, None, "world")] + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + await aconn.execute("set client_encoding to utf8") + async with cur.copy("copy copy_in from stdin (format text)") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +async def test_copy_in_format(aconn): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy: + for i in range(1, 256): + await copy.write_row((i, chr(i))) + + file.seek(0) + rows = file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_file_writer(aconn, format, buffer): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy: + for record in sample_records: + await copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + +@pytest.mark.slow +async def test_copy_from_to(aconn): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +async def test_copy_from_to_bytes(aconn, pytype): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(pytype(block.encode())) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +async def test_copy_from_insane_size(aconn): + # Trying to trigger a "would block" error + gen = DataGenerator( + aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024 + ) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + +async def test_copy_rowcount(aconn): + gen = DataGenerator(aconn, nrecs=3, srec=10) + await gen.ensure_table() + + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + assert cur.rowcount == 3 + + gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3) + async with cur.copy("copy copy_in from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == 2 + + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + pass + assert cur.rowcount == 5 + + with pytest.raises(e.BadCopyFileFormat): + async with cur.copy("copy copy_in (id) from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == -1 + + +async def test_copy_query(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert cur._query.query == b"copy (select 1) to stdout" + assert not cur._query.params + await alist(copy) + + +async def test_cant_reenter(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + await alist(copy) + + with pytest.raises(TypeError): + async with copy: + await alist(copy) + + +async def test_str(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert "[ACTIVE]" in str(copy) + await alist(copy) + + assert "[INTRANS]" in str(copy) + + +async def test_description(aconn): + async with aconn.cursor() as cur: + async with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy: + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + await alist(copy.rows()) + + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_worker_life(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", + writer=AsyncQueuedLibpqWriter(cur), + ) as copy: + assert not copy.writer._worker + await copy.write(globals()[buffer]) + assert copy.writer._worker + + assert not copy.writer._worker + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_worker_error_propagated(aconn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = aconn.cursor() + await cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy wat from stdin", writer=AsyncQueuedLibpqWriter(cur) + ) as copy: + await copy.write("a,b") + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_connection_writer(aconn, format, buffer): + cur = aconn.cursor() + writer = AsyncLibpqWriter(cur) + + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) +async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + + stmt = sql.SQL( + "copy (select {} from {} order by id) to stdout (format {})" + ).format( + sql.SQL(", ").join(faker.fields_names), + faker.table_name, + sql.SQL(fmt.name), + ) + + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + + if method == "read": + while True: + tmp = await copy.read() + if not tmp: + break + elif method == "iter": + await alist(copy) + elif method == "row": + while True: + tmp = await copy.read_row() + if tmp is None: + break + elif method == "rows": + await alist(copy.rows()) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin (format {})").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL(fmt.name), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["row", "block", "binary"]) +async def test_copy_table_across(aconn_cls, dsn, faker, mode): + faker.choose_schema(ncols=20) + faker.make_records(20) + + connect = aconn_cls.connect + async with await connect(dsn) as conn1, await connect(dsn) as conn2: + faker.table_name = sql.Identifier("copy_src") + await conn1.execute(faker.drop_stmt) + await conn1.execute(faker.create_stmt) + await conn1.cursor().executemany(faker.insert_stmt, faker.records) + + faker.table_name = sql.Identifier("copy_tgt") + await conn2.execute(faker.drop_stmt) + await conn2.execute(faker.create_stmt) + + fmt = "(format binary)" if mode == "binary" else "" + async with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1: + async with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2: + if mode == "row": + async for row in copy1.rows(): + await copy2.write_row(row) + else: + async for data in copy1: + await copy2.write(data) + + cur = await conn2.execute(faker.select_stmt) + recs = await cur.fetchall() + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +async def ensure_table(cur, tabledef, name="copy_in"): + await cur.execute(f"drop table if exists {name}") + await cur.execute(f"create table {name} ({tabledef})") + + +class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): + self.conn = conn + self.nrecs = nrecs + self.srec = srec + self.offset = offset + self.block_size = block_size + + async def ensure_table(self): + cur = self.conn.cursor() + await ensure_table(cur, "id integer primary key, data text") + + def records(self): + for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): + s = c * self.srec + yield (i + self.offset, s) + + def file(self): + f = StringIO() + for i, s in self.records(): + f.write("%s\t%s\n" % (i, s)) + + f.seek(0) + return f + + def blocks(self): + f = self.file() + while True: + block = f.read(self.block_size) + if not block: + break + yield block + + async def assert_data(self): + cur = self.conn.cursor() + await cur.execute("select id, data from copy_in order by id") + for record in self.records(): + assert record == await cur.fetchone() + + assert await cur.fetchone() is None + + def sha(self, f): + m = hashlib.sha256() + while True: + block = f.read() + if not block: + break + if isinstance(block, str): + block = block.encode() + m.update(block) + return m.hexdigest() + + +class AsyncFileWriter(AsyncWriter): + def __init__(self, file): + self.file = file + + async def write(self, data): + self.file.write(data) diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 0000000..a667f4f --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,942 @@ +import pickle +import weakref +import datetime as dt +from typing import List, Union +from contextlib import closing + +import pytest + +import psycopg +from psycopg import pq, sql, rows +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins +from psycopg.rows import RowMaker + +from .utils import gc_collect, gc_count +from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision + + +def test_init(conn): + cur = psycopg.Cursor(conn) + cur.execute("select 1") + assert cur.fetchone() == (1,) + + conn.row_factory = rows.dict_row + cur = psycopg.Cursor(conn) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_init_factory(conn): + cur = psycopg.Cursor(conn, row_factory=rows.dict_row) + cur.execute("select 1 as a") + assert cur.fetchone() == {"a": 1} + + +def test_close(conn): + cur = conn.cursor() + assert not cur.closed + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.execute("select 'foo'") + + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + cur.fetchall() + + +def test_context(conn): + with conn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +def test_weakref(conn): + cur = conn.cursor() + w = weakref.ref(cur) + cur.close() + del cur + gc_collect() + assert w() is None + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_statusmessage(conn): + cur = conn.cursor() + assert cur.statusmessage is None + + cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + cur.execute("wat") + assert cur.statusmessage is None + + +def test_execute_many_results(conn): + cur = conn.cursor() + assert cur.nextset() is None + + rv = cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert cur.fetchall() == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.nextset() is None + + cur.close() + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + rv = cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +def test_execute_empty_query(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +def test_execute_type_change(conn): + # issue #112 + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.execute(sql, (1,)) + cur.execute(sql, (100_000,)) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +def test_executemany_type_change(conn): + conn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = conn.cursor() + cur.executemany(sql, [(1,), (100_000,)]) + cur.execute("select num from bug_112 order by num") + assert cur.fetchall() == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +def test_execute_copy(conn, query): + cur = conn.cursor() + cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + cur.execute(query) + + +def test_fetchone(conn): + cur = conn.cursor() + cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = cur.fetchone() + assert row == (1, "foo", None) + row = cur.fetchone() + assert row is None + + +def test_binary_cursor_execute(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None]) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +def test_execute_binary(conn): + cur = conn.cursor() + cur.execute("select %s, %s", [1, None], binary=True) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + cur.execute("select %s, %s", [1, None], binary=False) + assert cur.fetchone() == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +def test_query_encode(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("select '\u20ac'").fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +def test_query_badenc(conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + with pytest.raises(UnicodeEncodeError): + cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +def test_executemany_no_data(conn, execmany): + cur = conn.cursor() + cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +def test_executemany_returning(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_returning_discard(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + assert cur.nextset() is None + + +def test_executemany_no_result(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +def test_executemany_rowcount_no_hit(conn, execmany): + cur = conn.cursor() + cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)]) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_executemany_null_first(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table testmany (a bigint, b bigint)") + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +def test_rowcount(conn): + cur = conn.cursor() + + cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + cur.execute("show timezone") + assert cur.rowcount == 1 + + cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)") + assert cur.rowcount == 42 + + +def test_rownumber(conn): + cur = conn.cursor() + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +def test_rownumber_none(conn, query): + cur = conn.cursor() + cur.execute(query) + assert cur.rownumber is None + + +def test_rownumber_mixed(conn): + cur = conn.cursor() + cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert cur.fetchone() == (4,) + assert cur.rownumber == 1 + + +def test_iter(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + assert list(cur) == [(1,), (2,), (3,)] + + +def test_iter_stop(conn): + cur = conn.cursor() + cur.execute("select generate_series(1, 3)") + for rec in cur: + assert rec == (1,) + break + + for rec in cur: + assert rec == (2,) + break + + assert cur.fetchone() == (3,) + assert list(cur) == [] + + +def test_row_factory(conn): + cur = conn.cursor(row_factory=my_row_factory) + + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" + + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert cur.fetchall() == [["Yy", "Zz"]] + + cur.scroll(-1) + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} + + +def test_row_factory_none(conn): + cur = conn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + r = cur.execute("select 1 as a, 2 as b").fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +def test_bad_row_factory(conn): + def broken_factory(cur): + 1 / 0 + + cur = conn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = conn.cursor(row_factory=broken_maker) + cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + cur.fetchone() + + +def test_scroll(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + cur.scroll(0) + + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(-1) + assert cur.fetchone() == (8,) + cur.scroll(-2) + assert cur.fetchone() == (7,) + cur.scroll(2, mode="absolute") + assert cur.fetchone() == (2,) + + # on the boundary + cur.scroll(0, mode="absolute") + assert cur.fetchone() == (0,) + with pytest.raises(IndexError): + cur.scroll(-1, mode="absolute") + + cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(-1) + + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + with pytest.raises(IndexError): + cur.scroll(10, mode="absolute") + + cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + cur.scroll(1) + + with pytest.raises(ValueError): + cur.scroll(1, "wat") + + +def test_query_params_execute(conn): + cur = conn.cursor() + assert cur._query is None + + cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select $1, $2::text" + assert cur._query.params == [b"1", None] + + cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select $1::int" + assert cur._query.params == [b"wat"] + + +def test_query_params_executemany(conn): + cur = conn.cursor() + + cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select $1, $2" + assert cur._query.params == [b"3", b"4"] + + +def test_stream(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +def test_stream_sql(conn): + cur = conn.cursor() + recs = list( + cur.stream( + sql.SQL( + "select i, '2021-01-01'::date + i from generate_series(1, {}) as i" + ).format(2) + ) + ) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +def test_stream_row_factory(conn): + cur = conn.cursor(row_factory=rows.dict_row) + it = iter(cur.stream("select generate_series(1,2) as a")) + assert next(it)["a"] == 1 + cur.row_factory = rows.namedtuple_row + assert next(it).a == 2 + + +def test_stream_no_row(conn): + cur = conn.cursor() + recs = list(cur.stream("select generate_series(2,1) as a")) + assert recs == [] + + +@pytest.mark.crdb_skip("no col query") +def test_stream_no_col(conn): + cur = conn.cursor() + recs = list(cur.stream("select")) + assert recs == [()] + + +@pytest.mark.parametrize( + "query", + [ + "create table test_stream_badq ()", + "copy (select 1) to stdout", + "wat?", + ], +) +def test_stream_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream(query): + pass + + +def test_stream_error_tx(conn): + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_stream_error_notx(conn): + conn.autocommit = True + cur = conn.cursor() + with pytest.raises(psycopg.ProgrammingError): + for rec in cur.stream("wat"): + pass + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + + +def test_stream_error_python_to_consume(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + with closing(cur.stream("select generate_series(1, 10000)")) as gen: + for rec in gen: + 1 / 0 + assert conn.info.transaction_status in ( + conn.TransactionStatus.INTRANS, + conn.TransactionStatus.INERROR, + ) + + +def test_stream_error_python_consumed(conn): + cur = conn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select 1") + for rec in gen: + 1 / 0 + gen.close() + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + + +def test_stream_close(conn): + cur = conn.cursor() + with pytest.raises(psycopg.OperationalError): + for rec in cur.stream("select generate_series(1, 3)"): + if rec[0] == 1: + conn.close() + else: + assert False + + assert conn.closed + + +def test_stream_binary_cursor(conn): + cur = conn.cursor(binary=True) + recs = [] + for rec in cur.stream("select x::int4 from generate_series(1, 2) x"): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +def test_stream_execute_binary(conn): + cur = conn.cursor() + recs = [] + for rec in cur.stream("select x::int4 from generate_series(1, 2) x", binary=True): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +def test_stream_binary_cursor_text_override(conn): + cur = conn.cursor(binary=True) + recs = [] + for rec in cur.stream("select generate_series(1, 2)", binary=False): + recs.append(rec) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode() + + assert recs == [(1,), (2,)] + + +class TestColumn: + def test_description_attribs(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + assert len(curs.description) == 3 + for c in curs.description: + len(c) == 7 # DBAPI happy + for i, a in enumerate( + """ + name type_code display_size internal_size precision scale null_ok + """.split() + ): + assert c[i] == getattr(c, a) + + # Won't fill them up + assert c.null_ok is None + + c = curs.description[0] + assert c.name == "pi" + assert c.type_code == builtins["numeric"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision == 10 + assert c.scale == 2 + + c = curs.description[1] + assert c.name == "hi" + assert c.type_code == builtins["text"].oid + assert c.display_size is None + assert c.internal_size is None + assert c.precision is None + assert c.scale is None + + c = curs.description[2] + assert c.name == "now" + assert c.type_code == builtins["date"].oid + assert c.display_size is None + if is_crdb(conn): + assert c.internal_size == 16 + else: + assert c.internal_size == 4 + assert c.precision is None + assert c.scale is None + + def test_description_slice(self, conn): + curs = conn.cursor() + curs.execute("select 1::int as a") + curs.description[0][0:2] == ("a", 23) + + @pytest.mark.parametrize( + "type, precision, scale, dsize, isize", + [ + ("text", None, None, None, None), + ("varchar", None, None, None, None), + ("varchar(42)", None, None, 42, None), + ("int4", None, None, None, 4), + ("numeric", None, None, None, None), + ("numeric(10)", 10, 0, None, None), + ("numeric(10, 3)", 10, 3, None, None), + ("time", None, None, None, 8), + crdb_time_precision("time(4)", 4, None, None, 8), + crdb_time_precision("time(10)", 6, None, None, 8), + ], + ) + def test_details(self, conn, type, precision, scale, dsize, isize): + cur = conn.cursor() + cur.execute(f"select null::{type}") + col = cur.description[0] + repr(col) + assert col.precision == precision + assert col.scale == scale + assert col.display_size == dsize + assert col.internal_size == isize + + def test_pickle(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + description = curs.description + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) + assert [tuple(d) for d in description] == [tuple(d) for d in unpickled] + + @pytest.mark.crdb_skip("no col query") + def test_no_col_query(self, conn): + cur = conn.execute("select") + assert cur.description == [] + assert cur.fetchall() == [()] + + def test_description_closed_connection(self, conn): + # If we have reasons to break this test we will (e.g. we really need + # the connection). In #172 it fails just by accident. + cur = conn.execute("select 1::int4 as foo") + conn.close() + assert len(cur.description) == 1 + col = cur.description[0] + assert col.name == "foo" + assert col.type_code == 23 + + def test_name_not_a_name(self, conn): + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "foo-bar" """).fetchone() + assert res == "x" + assert cur.description[0].name == "foo-bar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_name_encode(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + cur = conn.cursor() + (res,) = cur.execute("""select 'x' as "\u20ac" """).fetchone() + assert res == "x" + assert cur.description[0].name == "\u20ac" + + +def test_str(conn): + cur = conn.cursor() + assert "psycopg.Cursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): + faker.format = fmt + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + def work(): + with conn_cls.connect(dsn) as conn, conn.transaction(force_rollback=True): + with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + with faker.find_insert_problem(conn): + cur.executemany(faker.insert_stmt, faker.records) + + cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + cur.fetchall() + elif fetch == "iter": + for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def my_row_factory( + cursor: Union[psycopg.Cursor[List[str]], psycopg.AsyncCursor[List[str]]] +) -> RowMaker[List[str]]: + if cursor.description is not None: + titles = [c.name for c in cursor.description] + + def mkrow(values): + return [f"{value.upper()}{title}" for title, value in zip(titles, values)] + + return mkrow + else: + return rows.no_result diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py new file mode 100644 index 0000000..ac3fdeb --- /dev/null +++ b/tests/test_cursor_async.py @@ -0,0 +1,802 @@ +import pytest +import weakref +import datetime as dt +from typing import List + +import psycopg +from psycopg import pq, sql, rows +from psycopg.adapt import PyFormat + +from .utils import gc_collect, gc_count +from .test_cursor import my_row_factory +from .test_cursor import execmany, _execmany # noqa: F401 +from .fix_crdb import crdb_encoding + +execmany = execmany # avoid F811 underneath +pytestmark = pytest.mark.asyncio + + +async def test_init(aconn): + cur = psycopg.AsyncCursor(aconn) + await cur.execute("select 1") + assert (await cur.fetchone()) == (1,) + + aconn.row_factory = rows.dict_row + cur = psycopg.AsyncCursor(aconn) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_init_factory(aconn): + cur = psycopg.AsyncCursor(aconn, row_factory=rows.dict_row) + await cur.execute("select 1 as a") + assert (await cur.fetchone()) == {"a": 1} + + +async def test_close(aconn): + cur = aconn.cursor() + assert not cur.closed + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.execute("select 'foo'") + + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor() + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(psycopg.InterfaceError): + await cur.fetchall() + + +async def test_context(aconn): + async with aconn.cursor() as cur: + assert not cur.closed + + assert cur.closed + + +@pytest.mark.slow +async def test_weakref(aconn): + cur = aconn.cursor() + w = weakref.ref(cur) + await cur.close() + del cur + gc_collect() + assert w() is None + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_statusmessage(aconn): + cur = aconn.cursor() + assert cur.statusmessage is None + + await cur.execute("select generate_series(1, 10)") + assert cur.statusmessage == "SELECT 10" + + await cur.execute("create table statusmessage ()") + assert cur.statusmessage == "CREATE TABLE" + + with pytest.raises(psycopg.ProgrammingError): + await cur.execute("wat") + assert cur.statusmessage is None + + +async def test_execute_many_results(aconn): + cur = aconn.cursor() + assert cur.nextset() is None + + rv = await cur.execute("select 'foo'; select generate_series(1,3)") + assert rv is cur + assert (await cur.fetchall()) == [("foo",)] + assert cur.rowcount == 1 + assert cur.nextset() + assert (await cur.fetchall()) == [(1,), (2,), (3,)] + assert cur.rowcount == 3 + assert cur.nextset() is None + + await cur.close() + assert cur.nextset() is None + + +async def test_execute_sequence(aconn): + cur = aconn.cursor() + rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert rv is cur + assert len(cur._results) == 1 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.pgresult.get_value(0, 1) == b"foo" + assert cur.pgresult.get_value(0, 2) is None + assert cur.nextset() is None + + +@pytest.mark.parametrize("query", ["", " ", ";"]) +async def test_execute_empty_query(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.pgresult.status == cur.ExecStatus.EMPTY_QUERY + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + + +async def test_execute_type_change(aconn): + # issue #112 + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.execute(sql, (1,)) + await cur.execute(sql, (100_000,)) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +async def test_executemany_type_change(aconn): + await aconn.execute("create table bug_112 (num integer)") + sql = "insert into bug_112 (num) values (%s)" + cur = aconn.cursor() + await cur.executemany(sql, [(1,), (100_000,)]) + await cur.execute("select num from bug_112 order by num") + assert (await cur.fetchall()) == [(1,), (100_000,)] + + +@pytest.mark.parametrize( + "query", ["copy testcopy from stdin", "copy testcopy to stdout"] +) +async def test_execute_copy(aconn, query): + cur = aconn.cursor() + await cur.execute("create table testcopy (id int)") + with pytest.raises(psycopg.ProgrammingError): + await cur.execute(query) + + +async def test_fetchone(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) + assert cur.pgresult.fformat(0) == 0 + + row = await cur.fetchone() + assert row == (1, "foo", None) + row = await cur.fetchone() + assert row is None + + +async def test_binary_cursor_execute(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None]) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +async def test_execute_binary(aconn): + cur = aconn.cursor() + await cur.execute("select %s, %s", [1, None], binary=True) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x01" + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + await cur.execute("select %s, %s", [1, None], binary=False) + assert (await cur.fetchone()) == (1, None) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + + +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +async def test_query_encode(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + await cur.execute("select '\u20ac'") + (res,) = await cur.fetchone() + assert res == "\u20ac" + + +@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")]) +async def test_query_badenc(aconn, encoding): + await aconn.execute(f"set client_encoding to {encoding}") + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + await cur.execute("select '\u20ac'") + + +async def test_executemany(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(10, "hello"), (20, "world")] + + +async def test_executemany_name(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + await cur.execute("select num, data from execmany order by 1") + rv = await cur.fetchall() + assert rv == [(11, "hello"), (21, "world")] + + +async def test_executemany_no_data(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("insert into execmany(num, data) values (%s, %s)", []) + assert cur.rowcount == 0 + + +async def test_executemany_rowcount(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +async def test_executemany_returning(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_returning_discard(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s) returning num", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + assert cur.nextset() is None + + +async def test_executemany_no_result(aconn, execmany): + cur = aconn.cursor() + await cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.statusmessage.startswith("INSERT") + with pytest.raises(psycopg.ProgrammingError): + await cur.fetchone() + pgresult = cur.pgresult + assert cur.nextset() + assert cur.statusmessage.startswith("INSERT") + assert pgresult is not cur.pgresult + assert cur.nextset() is None + + +async def test_executemany_rowcount_no_hit(aconn, execmany): + cur = aconn.cursor() + await cur.executemany("delete from execmany where id = %s", [(-1,), (-2,)]) + assert cur.rowcount == 0 + await cur.executemany("delete from execmany where id = %s", []) + assert cur.rowcount == 0 + await cur.executemany( + "delete from execmany where id = %s returning num", [(-1,), (-2,)] + ) + assert cur.rowcount == 0 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +async def test_executemany_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.DatabaseError): + await cur.executemany(query, [(10, "hello"), (20, "world")]) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +async def test_executemany_null_first(aconn, fmt_in): + cur = aconn.cursor() + await cur.execute("create table testmany (a bigint, b bigint)") + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, None], [3, 4]], + ) + with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)): + await cur.executemany( + f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})", + [[1, ""], [3, 4]], + ) + + +async def test_rowcount(aconn): + cur = aconn.cursor() + + await cur.execute("select 1 from generate_series(1, 0)") + assert cur.rowcount == 0 + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rowcount == 42 + + await cur.execute("show timezone") + assert cur.rowcount == 1 + + await cur.execute("create table test_rowcount_notuples (id int primary key)") + assert cur.rowcount == -1 + + await cur.execute( + "insert into test_rowcount_notuples select generate_series(1, 42)" + ) + assert cur.rowcount == 42 + + +async def test_rownumber(aconn): + cur = aconn.cursor() + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + rns: List[int] = [] + async for i in cur: + assert cur.rownumber + rns.append(cur.rownumber) + if len(rns) >= 3: + break + assert rns == [13, 14, 15] + assert len(await cur.fetchall()) == 42 - rns[-1] + assert cur.rownumber == 42 + + +@pytest.mark.parametrize("query", ["", "set timezone to utc"]) +async def test_rownumber_none(aconn, query): + cur = aconn.cursor() + await cur.execute(query) + assert cur.rownumber is None + + +async def test_rownumber_mixed(aconn): + cur = aconn.cursor() + await cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) + assert cur.rownumber == 0 + assert await cur.fetchone() == (1,) + assert cur.rownumber == 1 + assert await cur.fetchone() == (2,) + assert cur.rownumber == 2 + cur.nextset() + assert cur.rownumber is None + cur.nextset() + assert cur.rownumber == 0 + assert await cur.fetchone() == (4,) + assert cur.rownumber == 1 + + +async def test_iter(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + res = [] + async for rec in cur: + res.append(rec) + assert res == [(1,), (2,), (3,)] + + +async def test_iter_stop(aconn): + cur = aconn.cursor() + await cur.execute("select generate_series(1, 3)") + async for rec in cur: + assert rec == (1,) + break + + async for rec in cur: + assert rec == (2,) + break + + assert (await cur.fetchone()) == (3,) + async for rec in cur: + assert False + + +async def test_row_factory(aconn): + cur = aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + await cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert await cur.fetchall() == [["Xx"]] + assert cur.nextset() + assert await cur.fetchall() == [["Yy", "Zz"]] + + await cur.scroll(-1) + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} + + +async def test_row_factory_none(aconn): + cur = aconn.cursor(row_factory=None) + assert cur.row_factory is rows.tuple_row + await cur.execute("select 1 as a, 2 as b") + r = await cur.fetchone() + assert type(r) is tuple + assert r == (1, 2) + + +async def test_bad_row_factory(aconn): + def broken_factory(cur): + 1 / 0 + + cur = aconn.cursor(row_factory=broken_factory) + with pytest.raises(ZeroDivisionError): + await cur.execute("select 1") + + def broken_maker(cur): + def make_row(seq): + 1 / 0 + + return make_row + + cur = aconn.cursor(row_factory=broken_maker) + await cur.execute("select 1") + with pytest.raises(ZeroDivisionError): + await cur.fetchone() + + +async def test_scroll(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + await cur.scroll(0) + + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(-1) + assert await cur.fetchone() == (8,) + await cur.scroll(-2) + assert await cur.fetchone() == (7,) + await cur.scroll(2, mode="absolute") + assert await cur.fetchone() == (2,) + + # on the boundary + await cur.scroll(0, mode="absolute") + assert await cur.fetchone() == (0,) + with pytest.raises(IndexError): + await cur.scroll(-1, mode="absolute") + + await cur.scroll(0, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(-1) + + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + with pytest.raises(IndexError): + await cur.scroll(10, mode="absolute") + + await cur.scroll(9, mode="absolute") + with pytest.raises(IndexError): + await cur.scroll(1) + + with pytest.raises(ValueError): + await cur.scroll(1, "wat") + + +async def test_query_params_execute(aconn): + cur = aconn.cursor() + assert cur._query is None + + await cur.execute("select %t, %s::text", [1, None]) + assert cur._query is not None + assert cur._query.query == b"select $1, $2::text" + assert cur._query.params == [b"1", None] + + await cur.execute("select 1") + assert cur._query.query == b"select 1" + assert not cur._query.params + + with pytest.raises(psycopg.DataError): + await cur.execute("select %t::int", ["wat"]) + + assert cur._query.query == b"select $1::int" + assert cur._query.params == [b"wat"] + + +async def test_query_params_executemany(aconn): + cur = aconn.cursor() + + await cur.executemany("select %t, %t", [[1, 2], [3, 4]]) + assert cur._query.query == b"select $1, $2" + assert cur._query.params == [b"3", b"4"] + + +async def test_stream(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select i, '2021-01-01'::date + i from generate_series(1, %s) as i", + [2], + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_stream_sql(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + sql.SQL( + "select i, '2021-01-01'::date + i from generate_series(1, {}) as i" + ).format(2) + ): + recs.append(rec) + + assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))] + + +async def test_stream_row_factory(aconn): + cur = aconn.cursor(row_factory=rows.dict_row) + ait = cur.stream("select generate_series(1,2) as a") + assert (await ait.__anext__())["a"] == 1 + cur.row_factory = rows.namedtuple_row + assert (await ait.__anext__()).a == 2 + + +async def test_stream_no_row(aconn): + cur = aconn.cursor() + recs = [rec async for rec in cur.stream("select generate_series(2,1) as a")] + assert recs == [] + + +@pytest.mark.crdb_skip("no col query") +async def test_stream_no_col(aconn): + cur = aconn.cursor() + recs = [rec async for rec in cur.stream("select")] + assert recs == [()] + + +@pytest.mark.parametrize( + "query", + [ + "create table test_stream_badq ()", + "copy (select 1) to stdout", + "wat?", + ], +) +async def test_stream_badquery(aconn, query): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream(query): + pass + + +async def test_stream_error_tx(aconn): + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_stream_error_notx(aconn): + await aconn.set_autocommit(True) + cur = aconn.cursor() + with pytest.raises(psycopg.ProgrammingError): + async for rec in cur.stream("wat"): + pass + assert aconn.info.transaction_status == aconn.TransactionStatus.IDLE + + +async def test_stream_error_python_to_consume(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select generate_series(1, 10000)") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status in ( + aconn.TransactionStatus.INTRANS, + aconn.TransactionStatus.INERROR, + ) + + +async def test_stream_error_python_consumed(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + gen = cur.stream("select 1") + async for rec in gen: + 1 / 0 + + await gen.aclose() + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_stream_close(aconn): + await aconn.set_autocommit(True) + cur = aconn.cursor() + with pytest.raises(psycopg.OperationalError): + async for rec in cur.stream("select generate_series(1, 3)"): + if rec[0] == 1: + await aconn.close() + else: + assert False + + assert aconn.closed + + +async def test_stream_binary_cursor(aconn): + cur = aconn.cursor(binary=True) + recs = [] + async for rec in cur.stream("select x::int4 from generate_series(1, 2) x"): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +async def test_stream_execute_binary(aconn): + cur = aconn.cursor() + recs = [] + async for rec in cur.stream( + "select x::int4 from generate_series(1, 2) x", binary=True + ): + recs.append(rec) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == bytes([0, 0, 0, rec[0]]) + + assert recs == [(1,), (2,)] + + +async def test_stream_binary_cursor_text_override(aconn): + cur = aconn.cursor(binary=True) + recs = [] + async for rec in cur.stream("select generate_series(1, 2)", binary=False): + recs.append(rec) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == str(rec[0]).encode() + + assert recs == [(1,), (2,)] + + +async def test_str(aconn): + cur = aconn.cursor() + assert "psycopg.AsyncCursor" in str(cur) + assert "[IDLE]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" in str(cur) + await cur.execute("select 1") + assert "[INTRANS]" in str(cur) + assert "[TUPLES_OK]" in str(cur) + assert "[closed]" not in str(cur) + assert "[no result]" not in str(cur) + await cur.close() + assert "[closed]" in str(cur) + assert "[INTRANS]" in str(cur) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"]) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory): + faker.format = fmt + faker.choose_schema(ncols=5) + faker.make_records(10) + row_factory = getattr(rows, row_factory) + + async def work(): + async with await aconn_cls.connect(dsn) as conn, conn.transaction( + force_rollback=True + ): + async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + await cur.execute(faker.select_stmt) + + if fetch == "one": + while True: + tmp = await cur.fetchone() + if tmp is None: + break + elif fetch == "many": + while True: + tmp = await cur.fetchmany(3) + if not tmp: + break + elif fetch == "all": + await cur.fetchall() + elif fetch == "iter": + async for rec in cur: + pass + + n = [] + gc_collect() + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 0000000..f50092f --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,27 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +pytestmark = [pytest.mark.dns] + + +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_warning(recwarn): + import_dnspython() + conninfo = "dbname=foo" + params = conninfo_to_dict(conninfo) + params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined] + params + ) + assert conninfo_to_dict(conninfo) == params + assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message) + + +def import_dnspython(): + try: + import dns.rdtypes.IN.A # noqa: F401 + except ImportError: + pytest.skip("dnspython package not available") + + import psycopg._dns # noqa: F401 diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py new file mode 100644 index 0000000..15b3706 --- /dev/null +++ b/tests/test_dns_srv.py @@ -0,0 +1,149 @@ +from typing import List, Union + +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +from .test_dns import import_dnspython + +pytestmark = [pytest.mark.dns] + +samples_ok = [ + ("", "", None), + ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None), + ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}), + ( + "host=foo.com,_pg._tcp.foo.com", + "host=foo.com,db1.example.com port=,5432", + None, + ), + ( + "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com", + "host=foo.com,db1.example.com port=,5432", + None, + ), + ( + "host=_pg._tcp.bar.com", + "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com" + " port=5432,5432,5433,5432", + None, + ), + ( + "host=service.foo.com port=srv", + "host=service.example.com port=15432", + None, + ), + # No resolution + ( + "host=_pg._tcp.foo.com hostaddr=1.1.1.1", + "host=_pg._tcp.foo.com hostaddr=1.1.1.1", + None, + ), +] + + +@pytest.mark.flakey("random weight order, might cause wrong order") +@pytest.mark.parametrize("conninfo, want, env", samples_ok) +def test_srv(conninfo, want, env, fake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = psycopg._dns.resolve_srv(params) # type: ignore[attr-defined] + assert conninfo_to_dict(want) == params + + +@pytest.mark.asyncio +@pytest.mark.parametrize("conninfo, want, env", samples_ok) +async def test_srv_async(conninfo, want, env, afake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = await ( + psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined] + ) + assert conninfo_to_dict(want) == params + + +samples_bad = [ + ("host=_pg._tcp.dot.com", None), + ("host=_pg._tcp.foo.com port=1,2", None), +] + + +@pytest.mark.parametrize("conninfo, env", samples_bad) +def test_srv_bad(conninfo, env, fake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.OperationalError): + psycopg._dns.resolve_srv(params) # type: ignore[attr-defined] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("conninfo, env", samples_bad) +async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.OperationalError): + await psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined] + + +@pytest.fixture +def fake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + monkeypatch.setattr( + psycopg._dns.resolver, # type: ignore[attr-defined] + "resolve", + f, + ) + + +@pytest.fixture +def afake_srv(monkeypatch): + f = get_fake_srv_function(monkeypatch) + + async def af(qname, rdtype): + return f(qname, rdtype) + + monkeypatch.setattr( + psycopg._dns.async_resolver, # type: ignore[attr-defined] + "resolve", + af, + ) + + +def get_fake_srv_function(monkeypatch): + import_dnspython() + + from dns.rdtypes.IN.A import A + from dns.rdtypes.IN.SRV import SRV + from dns.exception import DNSException + + fake_hosts = { + ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."], + ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."], + ("_pg._tcp.bar.com", "SRV"): [ + "1 0 5432 db2.example.com.", + "1 255 5433 db3.example.com.", + "0 0 5432 db1.example.com.", + "1 65535 5432 db4.example.com.", + ], + ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."], + } + + def fake_srv_(qname, rdtype): + try: + ans = fake_hosts[qname, rdtype] + except KeyError: + raise DNSException(f"unknown test host: {qname} {rdtype}") + rv: List[Union[A, SRV]] = [] + + if rdtype == "A": + for entry in ans: + rv.append(A("IN", "A", entry)) + else: + for entry in ans: + pri, w, port, target = entry.split() + rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target)) + + return rv + + return fake_srv_ diff --git a/tests/test_encodings.py b/tests/test_encodings.py new file mode 100644 index 0000000..113f0e3 --- /dev/null +++ b/tests/test_encodings.py @@ -0,0 +1,57 @@ +import codecs +import pytest + +import psycopg +from psycopg import _encodings as encodings + + +def test_names_normalised(): + for name in encodings._py_codecs.values(): + assert codecs.lookup(name).name == name + + +@pytest.mark.parametrize( + "pyenc, pgenc", + [ + ("ascii", "SQL_ASCII"), + ("utf8", "UTF8"), + ("utf-8", "UTF8"), + ("uTf-8", "UTF8"), + ("latin9", "LATIN9"), + ("iso8859-15", "LATIN9"), + ], +) +def test_py2pg(pyenc, pgenc): + assert encodings.py2pgenc(pyenc) == pgenc.encode() + + +@pytest.mark.parametrize( + "pyenc, pgenc", + [ + ("ascii", "SQL_ASCII"), + ("utf-8", "UTF8"), + ("iso8859-15", "LATIN9"), + ], +) +def test_pg2py(pyenc, pgenc): + assert encodings.pg2pyenc(pgenc.encode()) == pyenc + + +@pytest.mark.parametrize("pgenc", ["MULE_INTERNAL", "EUC_TW"]) +def test_pg2py_missing(pgenc): + with pytest.raises(psycopg.NotSupportedError): + encodings.pg2pyenc(pgenc.encode()) + + +@pytest.mark.parametrize( + "conninfo, pyenc", + [ + ("", "utf-8"), + ("user=foo, dbname=bar", "utf-8"), + ("user=foo, dbname=bar, client_encoding=EUC_JP", "euc_jp"), + ("user=foo, dbname=bar, client_encoding=euc-jp", "euc_jp"), + ("user=foo, dbname=bar, client_encoding=WAT", "utf-8"), + ], +) +def test_conninfo_encoding(conninfo, pyenc): + assert encodings.conninfo_encoding(conninfo) == pyenc diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..23ad314 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,309 @@ +import pickle +from typing import List +from weakref import ref + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +from .utils import eur, gc_collect +from .fix_crdb import is_crdb + + +@pytest.mark.crdb_skip("severity_nonlocalized") +def test_error_diag(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + exc = excinfo.value + diag = exc.diag + assert diag.sqlstate == "42P01" + assert diag.severity_nonlocalized == "ERROR" + + +def test_diag_all_attrs(pgconn): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + for d in pq.DiagnosticField: + val = getattr(diag, d.name.lower()) + assert val is None or isinstance(val, str) + + +def test_diag_right_attr(pgconn, monkeypatch): + res = pgconn.make_empty_result(pq.ExecStatus.NONFATAL_ERROR) + diag = e.Diagnostic(res) + + to_check: pq.DiagnosticField + checked: List[pq.DiagnosticField] = [] + + def check_val(self, v): + nonlocal to_check + assert to_check == v + checked.append(v) + return None + + monkeypatch.setattr(e.Diagnostic, "_error_message", check_val) + + for to_check in pq.DiagnosticField: + getattr(diag, to_check.name.lower()) + + assert len(checked) == len(pq.DiagnosticField) + + +def test_diag_attr_values(conn): + if is_crdb(conn): + conn.execute("set experimental_enable_temp_tables = 'on'") + conn.execute( + """ + create temp table test_exc ( + data int constraint chk_eq1 check (data = 1) + )""" + ) + with pytest.raises(e.Error) as exc: + conn.execute("insert into test_exc values(2)") + diag = exc.value.diag + assert diag.sqlstate == "23514" + assert diag.constraint_name == "chk_eq1" + if not is_crdb(conn): + assert diag.table_name == "test_exc" + assert diag.schema_name and diag.schema_name[:7] == "pg_temp" + assert diag.severity_nonlocalized == "ERROR" + + +@pytest.mark.crdb_skip("do") +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_diag_encoding(conn, enc): + msgs = [] + conn.pgconn.exec_(b"set client_min_messages to notice") + conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary)) + conn.execute(f"set client_encoding to {enc}") + cur = conn.cursor() + cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql") + assert msgs == [f"hello {eur}"] + + +@pytest.mark.crdb_skip("do") +@pytest.mark.parametrize("enc", ["utf8", "latin9"]) +def test_error_encoding(conn, enc): + with conn.transaction(): + conn.execute(f"set client_encoding to {enc}") + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute( + """ + do $$begin + execute format('insert into "%s" values (1)', chr(8364)); + end$$ language plpgsql; + """ + ) + + diag = excinfo.value.diag + assert diag.message_primary and f'"{eur}"' in diag.message_primary + assert diag.sqlstate == "42P01" + + +def test_exception_class(conn): + cur = conn.cursor() + + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select * from nonexist") + + assert isinstance(excinfo.value, e.UndefinedTable) + assert isinstance(excinfo.value, conn.ProgrammingError) + + +def test_exception_class_fallback(conn): + cur = conn.cursor() + + x = e._sqlcodes.pop("42P01") + try: + with pytest.raises(e.Error) as excinfo: + cur.execute("select * from nonexist") + finally: + e._sqlcodes["42P01"] = x + + assert type(excinfo.value) is conn.ProgrammingError + + +def test_lookup(): + assert e.lookup("42P01") is e.UndefinedTable + assert e.lookup("42p01") is e.UndefinedTable + assert e.lookup("UNDEFINED_TABLE") is e.UndefinedTable + assert e.lookup("undefined_table") is e.UndefinedTable + + with pytest.raises(KeyError): + e.lookup("XXXXX") + + +def test_error_sqlstate(): + assert e.Error.sqlstate is None + assert e.ProgrammingError.sqlstate is None + assert e.UndefinedTable.sqlstate == "42P01" + + +def test_error_pickle(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert isinstance(exc, e.UndefinedTable) + assert exc.diag.sqlstate == "42P01" + + +def test_diag_pickle(conn): + cur = conn.cursor() + with pytest.raises(e.DatabaseError) as excinfo: + cur.execute("select 1 from wat") + + diag1 = excinfo.value.diag + diag2 = pickle.loads(pickle.dumps(diag1)) + + assert isinstance(diag2, type(diag1)) + for f in pq.DiagnosticField: + assert getattr(diag1, f.name.lower()) == getattr(diag2, f.name.lower()) + + assert diag2.sqlstate == "42P01" + + +@pytest.mark.slow +def test_diag_survives_cursor(conn): + cur = conn.cursor() + with pytest.raises(e.Error) as exc: + cur.execute("select * from nosuchtable") + + diag = exc.value.diag + del exc + w = ref(cur) + del cur + gc_collect() + assert w() is None + assert diag.sqlstate == "42P01" + + +def test_diag_independent(conn): + conn.autocommit = True + cur = conn.cursor() + + with pytest.raises(e.Error) as exc1: + cur.execute("l'acqua e' poca e 'a papera nun galleggia") + + with pytest.raises(e.Error) as exc2: + cur.execute("select level from water where ducks > 1") + + assert exc1.value.diag.sqlstate == "42601" + assert exc2.value.diag.sqlstate == "42P01" + + +@pytest.mark.crdb_skip("deferrable") +def test_diag_from_commit(conn): + cur = conn.cursor() + cur.execute( + """ + create temp table test_deferred ( + data int primary key, + ref int references test_deferred (data) + deferrable initially deferred) + """ + ) + cur.execute("insert into test_deferred values (1,2)") + with pytest.raises(e.Error) as exc: + conn.commit() + + assert exc.value.diag.sqlstate == "23503" + + +@pytest.mark.asyncio +@pytest.mark.crdb_skip("deferrable") +async def test_diag_from_commit_async(aconn): + cur = aconn.cursor() + await cur.execute( + """ + create temp table test_deferred ( + data int primary key, + ref int references test_deferred (data) + deferrable initially deferred) + """ + ) + await cur.execute("insert into test_deferred values (1,2)") + with pytest.raises(e.Error) as exc: + await aconn.commit() + + assert exc.value.diag.sqlstate == "23503" + + +def test_query_context(conn): + with pytest.raises(e.Error) as exc: + conn.execute("select * from wat") + + s = str(exc.value) + if not is_crdb(conn): + assert "from wat" in s, s + assert exc.value.diag.message_primary + assert exc.value.diag.message_primary in s + assert "ERROR" not in s + assert not s.endswith("\n") + + +@pytest.mark.crdb_skip("do") +def test_unknown_sqlstate(conn): + code = "PXX99" + with pytest.raises(KeyError): + e.lookup(code) + + with pytest.raises(e.ProgrammingError) as excinfo: + conn.execute( + f""" + do $$begin + raise exception 'made up code' using errcode = '{code}'; + end$$ language plpgsql + """ + ) + exc = excinfo.value + assert exc.diag.sqlstate == code + assert exc.sqlstate == code + # Survives pickling too + pexc = pickle.loads(pickle.dumps(exc)) + assert pexc.sqlstate == code + + +def test_pgconn_error(conn_cls): + with pytest.raises(psycopg.OperationalError) as excinfo: + conn_cls.connect("dbname=nosuchdb") + + exc = excinfo.value + assert exc.pgconn + assert exc.pgconn.db == b"nosuchdb" + + +def test_pgconn_error_pickle(conn_cls): + with pytest.raises(psycopg.OperationalError) as excinfo: + conn_cls.connect("dbname=nosuchdb") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert exc.pgconn is None + + +def test_pgresult(conn): + with pytest.raises(e.DatabaseError) as excinfo: + conn.execute("select 1 from wat") + + exc = excinfo.value + assert exc.pgresult + assert exc.pgresult.error_field(pq.DiagnosticField.SQLSTATE) == b"42P01" + + +def test_pgresult_pickle(conn): + with pytest.raises(e.DatabaseError) as excinfo: + conn.execute("select 1 from wat") + + exc = pickle.loads(pickle.dumps(excinfo.value)) + assert exc.pgresult is None + assert exc.diag.sqlstate == "42P01" + + +def test_blank_sqlstate(conn): + assert e.get_base_exception("") is e.DatabaseError diff --git a/tests/test_generators.py b/tests/test_generators.py new file mode 100644 index 0000000..8aba73f --- /dev/null +++ b/tests/test_generators.py @@ -0,0 +1,156 @@ +from collections import deque +from functools import partial +from typing import List + +import pytest + +import psycopg +from psycopg import waiting +from psycopg import pq + + +@pytest.fixture +def pipeline(pgconn): + nb, pgconn.nonblocking = pgconn.nonblocking, True + assert pgconn.nonblocking + pgconn.enter_pipeline_mode() + yield + if pgconn.pipeline_status: + pgconn.exit_pipeline_mode() + pgconn.nonblocking = nb + + +def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses): + actual_statuses: List[pq.ExecStatus] = [] + while len(actual_statuses) != len(expected_statuses): + if commands: + gen = generators.pipeline_communicate(pgconn, commands) + results = waiting.wait(gen, pgconn.socket) + for (result,) in results: + actual_statuses.append(result.status) + else: + gen = generators.fetch_many(pgconn) + results = waiting.wait(gen, pgconn.socket) + for result in results: + actual_statuses.append(result.status) + + assert actual_statuses == expected_statuses + + +@pytest.mark.pipeline +def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_params, b"select 1", None), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, b"select 2", None), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.mark.pipeline +def test_pipeline_communicate_no_sync(pgconn, pipeline, generators): + numqueries = 10 + commands = deque( + [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)] + * numqueries + + [pgconn.send_flush_request] + ) + expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.fixture +def pipeline_demo(pgconn): + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + yield "pg_pipeline" + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + +# TODOCRDB: 1 doesn't get rolled back. Open a ticket? +@pytest.mark.pipeline +@pytest.mark.crdb("skip", reason="pipeline aborted") +def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators): + insert_sql = b"insert into pg_pipeline(itemno) values ($1)" + commands = deque( + [ + partial(pgconn.send_query_params, insert_sql, [b"1"]), + partial(pgconn.send_query_params, b"select no_such_function(1)", None), + partial(pgconn.send_query_params, insert_sql, [b"2"]), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, insert_sql, [b"3"]), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + pgconn.exit_pipeline_mode() + res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno") + assert res.ntuples == 1 + assert res.get_value(0, 0) == b"3" + + +@pytest.fixture +def pipeline_uniqviol(pgconn): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol(" + b" id bigint primary key, idata bigint)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_(b"BEGIN") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.prepare( + b"insertion", + b"insert into pg_pipeline_uniqviol values ($1, $2) returning id", + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + return "pg_pipeline_uniqviol" + + +def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]), + partial(pgconn.send_query_params, b"commit", None), + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 0000000..794ef0f --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,57 @@ +import pytest + +from psycopg._cmodule import _psycopg + + +@pytest.mark.parametrize( + "args, kwargs, want_conninfo", + [ + ((), {}, ""), + (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"), + ((), {"port": 15432}, "port=15432"), + ((), {"user": "foo", "dbname": None}, "user=foo"), + ], +) +def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): + # Check the main args passing from psycopg.connect to the conn generator + # Details of the params manipulation are in test_conninfo. + import psycopg.connection + + orig_connect = psycopg.connection.connect # type: ignore + + got_conninfo = None + + def mock_connect(conninfo): + nonlocal got_conninfo + got_conninfo = conninfo + return orig_connect(dsn) + + monkeypatch.setattr(psycopg.connection, "connect", mock_connect) + + conn = psycopg.connect(*args, **kwargs) + assert got_conninfo == want_conninfo + conn.close() + + +def test_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg import __version__ +assert __version__ +""" + ) + assert not cp.stdout + + +@pytest.mark.skipif(_psycopg is None, reason="C module test") +def test_version_c(mypy): + # can be psycopg_c, psycopg_binary + cpackage = _psycopg.__name__.split(".")[0] + + cp = mypy.run_on_source( + f"""\ +from {cpackage} import __version__ +assert __version__ +""" + ) + assert not cp.stdout diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..56fe598 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,577 @@ +import logging +import concurrent.futures +from typing import Any +from operator import attrgetter +from itertools import groupby + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +pytestmark = [ + pytest.mark.pipeline, + pytest.mark.skipif("not psycopg.Pipeline.is_supported()"), +] + +pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted") + + +def test_repr(conn): + with conn.pipeline() as p: + assert "psycopg.Pipeline" in repr(p) + assert "[IDLE, pipeline=ON]" in repr(p) + + conn.close() + assert "[BAD]" in repr(p) + + +def test_connection_closed(conn): + conn.close() + with pytest.raises(e.OperationalError): + with conn.pipeline(): + pass + + +def test_pipeline_status(conn: psycopg.Connection[Any]) -> None: + assert conn._pipeline is None + with conn.pipeline() as p: + assert conn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert p.status == pq.PipelineStatus.OFF + assert not conn._pipeline + + +def test_pipeline_reenter(conn: psycopg.Connection[Any]) -> None: + with conn.pipeline() as p1: + with conn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON + assert conn._pipeline is None + assert p1.status == pq.PipelineStatus.OFF + + +def test_pipeline_broken_conn_exit(conn: psycopg.Connection[Any]) -> None: + with pytest.raises(e.OperationalError): + with conn.pipeline(): + conn.execute("select 1") + conn.close() + closed = True + + assert closed + + +def test_pipeline_exit_error_noclobber(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + with conn.pipeline(): + conn.close() + 1 / 0 + + assert len(caplog.records) == 1 + + +def test_pipeline_exit_error_noclobber_nested(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + with conn.pipeline(): + with conn.pipeline(): + conn.close() + 1 / 0 + + assert len(caplog.records) == 2 + + +def test_pipeline_exit_sync_trace(conn, trace): + t = trace.trace(conn) + with conn.pipeline(): + pass + conn.close() + assert len([i for i in t if i.type == "Sync"]) == 1 + + +def test_pipeline_nested_sync_trace(conn, trace): + t = trace.trace(conn) + with conn.pipeline(): + with conn.pipeline(): + pass + conn.close() + assert len([i for i in t if i.type == "Sync"]) == 2 + + +def test_cursor_stream(conn): + with conn.pipeline(), conn.cursor() as cur: + with pytest.raises(psycopg.ProgrammingError): + cur.stream("select 1").__next__() + + +def test_server_cursor(conn): + with conn.cursor(name="pipeline") as cur, conn.pipeline(): + with pytest.raises(psycopg.NotSupportedError): + cur.execute("select 1") + + +def test_cannot_insert_multiple_commands(conn): + with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)): + with conn.pipeline(): + conn.execute("select 1; select 2") + + +def test_copy(conn): + with conn.pipeline(): + cur = conn.cursor() + with pytest.raises(e.NotSupportedError): + with cur.copy("copy (select 1) to stdout"): + pass + + +def test_pipeline_processed_at_exit(conn): + with conn.cursor() as cur: + with conn.pipeline() as p: + cur.execute("select 1") + + assert len(p.result_queue) == 1 + + assert cur.fetchone() == (1,) + + +def test_pipeline_errors_processed_at_exit(conn): + conn.autocommit = True + with pytest.raises(e.UndefinedTable): + with conn.pipeline(): + conn.execute("select * from nosuchtable") + conn.execute("create table voila ()") + cur = conn.execute( + "select count(*) from pg_tables where tablename = %s", ("voila",) + ) + (count,) = cur.fetchone() + assert count == 0 + + +def test_pipeline(conn): + with conn.pipeline() as p: + c1 = conn.cursor() + c2 = conn.cursor() + c1.execute("select 1") + c2.execute("select 2") + + assert len(p.result_queue) == 2 + + (r1,) = c1.fetchone() + assert r1 == 1 + + (r2,) = c2.fetchone() + assert r2 == 2 + + +def test_autocommit(conn): + conn.autocommit = True + with conn.pipeline(), conn.cursor() as c: + c.execute("select 1") + + (r,) = c.fetchone() + assert r == 1 + + +def test_pipeline_aborted(conn): + conn.autocommit = True + with conn.pipeline() as p: + c1 = conn.execute("select 1") + with pytest.raises(e.UndefinedTable): + conn.execute("select * from doesnotexist").fetchone() + with pytest.raises(e.PipelineAborted): + conn.execute("select 'aborted'").fetchone() + # Sync restore the connection in usable state. + p.sync() + c2 = conn.execute("select 2") + + (r,) = c1.fetchone() + assert r == 1 + + (r,) = c2.fetchone() + assert r == 2 + + +def test_pipeline_commit_aborted(conn): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + with conn.pipeline(): + conn.execute("select error") + conn.execute("create table voila ()") + conn.commit() + + +def test_sync_syncs_results(conn): + with conn.pipeline() as p: + cur = conn.execute("select 1") + assert cur.statusmessage is None + p.sync() + assert cur.statusmessage == "SELECT 1" + + +def test_sync_syncs_errors(conn): + conn.autocommit = True + with conn.pipeline() as p: + conn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + p.sync() + + +@pipeline_aborted +def test_errors_raised_on_commit(conn): + with conn.pipeline(): + conn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + conn.commit() + conn.rollback() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +def test_errors_raised_on_transaction_exit(conn): + here = False + with conn.pipeline(): + with pytest.raises(e.UndefinedTable): + with conn.transaction(): + conn.execute("select 1 from nosuchtable") + here = True + cur1 = conn.execute("select 1") + assert here + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +def test_errors_raised_on_nested_transaction_exit(conn): + here = False + with conn.pipeline(): + with conn.transaction(): + with pytest.raises(e.UndefinedTable): + with conn.transaction(): + conn.execute("select 1 from nosuchtable") + here = True + cur1 = conn.execute("select 1") + assert here + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +def test_implicit_transaction(conn): + conn.autocommit = True + with conn.pipeline(): + assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE + conn.execute("select 'before'") + # Transaction is ACTIVE because previous command is not completed + # since we have not fetched its results. + assert conn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE + # Upon entering the nested pipeline through "with transaction():", a + # sync() is emitted to restore the transaction state to IDLE, as + # expected to emit a BEGIN. + with conn.transaction(): + conn.execute("select 'tx'") + cur = conn.execute("select 'after'") + assert cur.fetchone() == ("after",) + + +@pytest.mark.crdb_skip("deferrable") +def test_error_on_commit(conn): + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + conn.commit() + + with conn.pipeline(): + conn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + conn.commit() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +def test_fetch_no_result(conn): + with conn.pipeline(): + cur = conn.cursor() + with pytest.raises(e.ProgrammingError): + cur.fetchone() + + +def test_executemany(conn): + conn.autocommit = True + conn.execute("drop table if exists execmanypipeline") + conn.execute( + "create unlogged table execmanypipeline (" + " id serial primary key, num integer)" + ) + with conn.pipeline(), conn.cursor() as cur: + cur.executemany( + "insert into execmanypipeline(num) values (%s) returning num", + [(10,), (20,)], + returning=True, + ) + assert cur.rowcount == 2 + assert cur.fetchone() == (10,) + assert cur.nextset() + assert cur.fetchone() == (20,) + assert cur.nextset() is None + + +def test_executemany_no_returning(conn): + conn.autocommit = True + conn.execute("drop table if exists execmanypipelinenoreturning") + conn.execute( + "create unlogged table execmanypipelinenoreturning (" + " id serial primary key, num integer)" + ) + with conn.pipeline(), conn.cursor() as cur: + cur.executemany( + "insert into execmanypipelinenoreturning(num) values (%s)", + [(10,), (20,)], + returning=False, + ) + with pytest.raises(e.ProgrammingError, match="no result available"): + cur.fetchone() + assert cur.nextset() is None + with pytest.raises(e.ProgrammingError, match="no result available"): + cur.fetchone() + assert cur.nextset() is None + + +@pytest.mark.crdb("skip", reason="temp tables") +def test_executemany_trace(conn, trace): + conn.autocommit = True + cur = conn.cursor() + cur.execute("create temp table trace (id int)") + t = trace.trace(conn) + with conn.pipeline(): + cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + conn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] + assert len([i for i in items if i.type == "Sync"]) == 1 + + +@pytest.mark.crdb("skip", reason="temp tables") +def test_executemany_trace_returning(conn, trace): + conn.autocommit = True + cur = conn.cursor() + cur.execute("create temp table trace (id int)") + t = trace.trace(conn) + with conn.pipeline(): + cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + conn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] * 3 + assert items[-2].direction == "F" # last 2 items are F B + assert len([i for i in items if i.type == "Sync"]) == 1 + + +def test_prepared(conn): + conn.autocommit = True + with conn.pipeline(): + c1 = conn.execute("select %s::int", [10], prepare=True) + c2 = conn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + + (r,) = c1.fetchone() + assert r == 10 + + (r,) = c2.fetchone() + assert r == 1 + + +def test_auto_prepare(conn): + conn.autocommit = True + conn.prepared_threshold = 5 + with conn.pipeline(): + cursors = [ + conn.execute("select count(*) from pg_prepared_statements where name != ''") + for i in range(10) + ] + + assert len(conn._prepared._names) == 1 + + res = [c.fetchone()[0] for c in cursors] + assert res == [0] * 5 + [1] * 5 + + +def test_transaction(conn): + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + with conn.pipeline(): + with conn.transaction(): + cur = conn.execute("select 'tx'") + + (r,) = cur.fetchone() + assert r == "tx" + + with conn.transaction(): + cur = conn.execute("select 'rb'") + raise psycopg.Rollback() + + (r,) = cur.fetchone() + assert r == "rb" + + assert not notices + + +def test_transaction_nested(conn): + with conn.pipeline(): + with conn.transaction(): + outer = conn.execute("select 'outer'") + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + inner = conn.execute("select 'inner'") + 1 / 0 + + (r,) = outer.fetchone() + assert r == "outer" + (r,) = inner.fetchone() + assert r == "inner" + + +def test_transaction_nested_no_statement(conn): + with conn.pipeline(): + with conn.transaction(): + with conn.transaction(): + cur = conn.execute("select 1") + + (r,) = cur.fetchone() + assert r == 1 + + +def test_outer_transaction(conn): + with conn.transaction(): + conn.execute("drop table if exists outertx") + with conn.transaction(): + with conn.pipeline(): + conn.execute("create table outertx as (select 1)") + cur = conn.execute("select * from outertx") + (r,) = cur.fetchone() + assert r == 1 + cur = conn.execute("select count(*) from pg_tables where tablename = 'outertx'") + assert cur.fetchone()[0] == 1 + + +def test_outer_transaction_error(conn): + with conn.transaction(): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + with conn.pipeline(): + conn.execute("select error") + conn.execute("create table voila ()") + + +def test_rollback_explicit(conn): + conn.autocommit = True + with conn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.rollback() + conn.execute("select 1") + + +def test_rollback_transaction(conn): + conn.autocommit = True + with pytest.raises(e.DivisionByZero): + with conn.pipeline(): + with conn.transaction(): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.execute("select 1") + + +def test_message_0x33(conn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + conn.autocommit = True + with conn.pipeline(): + cur = conn.execute("select 'test'") + assert cur.fetchone() == ("test",) + + assert not notices + + +def test_transaction_state_implicit_begin(conn, trace): + # Regression test to ensure that the transaction state is correct after + # the implicit BEGIN statement (in non-autocommit mode). + notices = [] + conn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + t = trace.trace(conn) + with conn.pipeline(): + conn.execute("select 'x'").fetchone() + conn.execute("select 'y'") + assert not notices + assert [ + e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0] + ] == [b' "" "BEGIN" 0'] + + +def test_concurrency(conn): + with conn.transaction(): + conn.execute("drop table if exists pipeline_concurrency") + conn.execute("drop table if exists accessed") + with conn.transaction(): + conn.execute( + "create unlogged table pipeline_concurrency (" + " id serial primary key," + " value integer" + ")" + ) + conn.execute("create unlogged table accessed as (select now() as value)") + + def update(value): + cur = conn.execute( + "insert into pipeline_concurrency(value) values (%s) returning value", + (value,), + ) + conn.execute("update accessed set value = now()") + return cur + + conn.autocommit = True + + (before,) = conn.execute("select value from accessed").fetchone() + + values = range(1, 10) + with conn.pipeline(): + with concurrent.futures.ThreadPoolExecutor() as e: + cursors = e.map(update, values, timeout=len(values)) + assert sum(cur.fetchone()[0] for cur in cursors) == sum(values) + + (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone() + assert s == sum(values) + (after,) = conn.execute("select value from accessed").fetchone() + assert after > before diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py new file mode 100644 index 0000000..2e743cf --- /dev/null +++ b/tests/test_pipeline_async.py @@ -0,0 +1,586 @@ +import asyncio +import logging +from typing import Any +from operator import attrgetter +from itertools import groupby + +import pytest + +import psycopg +from psycopg import pq +from psycopg import errors as e + +from .test_pipeline import pipeline_aborted + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.pipeline, + pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"), +] + + +async def test_repr(aconn): + async with aconn.pipeline() as p: + assert "psycopg.AsyncPipeline" in repr(p) + assert "[IDLE, pipeline=ON]" in repr(p) + + await aconn.close() + assert "[BAD]" in repr(p) + + +async def test_connection_closed(aconn): + await aconn.close() + with pytest.raises(e.OperationalError): + async with aconn.pipeline(): + pass + + +async def test_pipeline_status(aconn: psycopg.AsyncConnection[Any]) -> None: + assert aconn._pipeline is None + async with aconn.pipeline() as p: + assert aconn._pipeline is p + assert p.status == pq.PipelineStatus.ON + assert p.status == pq.PipelineStatus.OFF + assert not aconn._pipeline + + +async def test_pipeline_reenter(aconn: psycopg.AsyncConnection[Any]) -> None: + async with aconn.pipeline() as p1: + async with aconn.pipeline() as p2: + assert p2 is p1 + assert p1.status == pq.PipelineStatus.ON + assert p2 is p1 + assert p2.status == pq.PipelineStatus.ON + assert aconn._pipeline is None + assert p1.status == pq.PipelineStatus.OFF + + +async def test_pipeline_broken_conn_exit(aconn: psycopg.AsyncConnection[Any]) -> None: + with pytest.raises(e.OperationalError): + async with aconn.pipeline(): + await aconn.execute("select 1") + await aconn.close() + closed = True + + assert closed + + +async def test_pipeline_exit_error_noclobber(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + async with aconn.pipeline(): + await aconn.close() + 1 / 0 + + assert len(caplog.records) == 1 + + +async def test_pipeline_exit_error_noclobber_nested(aconn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + with pytest.raises(ZeroDivisionError): + async with aconn.pipeline(): + async with aconn.pipeline(): + await aconn.close() + 1 / 0 + + assert len(caplog.records) == 2 + + +async def test_pipeline_exit_sync_trace(aconn, trace): + t = trace.trace(aconn) + async with aconn.pipeline(): + pass + await aconn.close() + assert len([i for i in t if i.type == "Sync"]) == 1 + + +async def test_pipeline_nested_sync_trace(aconn, trace): + t = trace.trace(aconn) + async with aconn.pipeline(): + async with aconn.pipeline(): + pass + await aconn.close() + assert len([i for i in t if i.type == "Sync"]) == 2 + + +async def test_cursor_stream(aconn): + async with aconn.pipeline(), aconn.cursor() as cur: + with pytest.raises(psycopg.ProgrammingError): + await cur.stream("select 1").__anext__() + + +async def test_server_cursor(aconn): + async with aconn.cursor(name="pipeline") as cur, aconn.pipeline(): + with pytest.raises(psycopg.NotSupportedError): + await cur.execute("select 1") + + +async def test_cannot_insert_multiple_commands(aconn): + with pytest.raises((e.SyntaxError, e.InvalidPreparedStatementDefinition)): + async with aconn.pipeline(): + await aconn.execute("select 1; select 2") + + +async def test_copy(aconn): + async with aconn.pipeline(): + cur = aconn.cursor() + with pytest.raises(e.NotSupportedError): + async with cur.copy("copy (select 1) to stdout") as copy: + await copy.read() + + +async def test_pipeline_processed_at_exit(aconn): + async with aconn.cursor() as cur: + async with aconn.pipeline() as p: + await cur.execute("select 1") + + assert len(p.result_queue) == 1 + + assert await cur.fetchone() == (1,) + + +async def test_pipeline_errors_processed_at_exit(aconn): + await aconn.set_autocommit(True) + with pytest.raises(e.UndefinedTable): + async with aconn.pipeline(): + await aconn.execute("select * from nosuchtable") + await aconn.execute("create table voila ()") + cur = await aconn.execute( + "select count(*) from pg_tables where tablename = %s", ("voila",) + ) + (count,) = await cur.fetchone() + assert count == 0 + + +async def test_pipeline(aconn): + async with aconn.pipeline() as p: + c1 = aconn.cursor() + c2 = aconn.cursor() + await c1.execute("select 1") + await c2.execute("select 2") + + assert len(p.result_queue) == 2 + + (r1,) = await c1.fetchone() + assert r1 == 1 + + (r2,) = await c2.fetchone() + assert r2 == 2 + + +async def test_autocommit(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(), aconn.cursor() as c: + await c.execute("select 1") + + (r,) = await c.fetchone() + assert r == 1 + + +async def test_pipeline_aborted(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline() as p: + c1 = await aconn.execute("select 1") + with pytest.raises(e.UndefinedTable): + await (await aconn.execute("select * from doesnotexist")).fetchone() + with pytest.raises(e.PipelineAborted): + await (await aconn.execute("select 'aborted'")).fetchone() + # Sync restore the connection in usable state. + await p.sync() + c2 = await aconn.execute("select 2") + + (r,) = await c1.fetchone() + assert r == 1 + + (r,) = await c2.fetchone() + assert r == 2 + + +async def test_pipeline_commit_aborted(aconn): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + async with aconn.pipeline(): + await aconn.execute("select error") + await aconn.execute("create table voila ()") + await aconn.commit() + + +async def test_sync_syncs_results(aconn): + async with aconn.pipeline() as p: + cur = await aconn.execute("select 1") + assert cur.statusmessage is None + await p.sync() + assert cur.statusmessage == "SELECT 1" + + +async def test_sync_syncs_errors(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline() as p: + await aconn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + await p.sync() + + +@pipeline_aborted +async def test_errors_raised_on_commit(aconn): + async with aconn.pipeline(): + await aconn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + await aconn.commit() + await aconn.rollback() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +async def test_errors_raised_on_transaction_exit(aconn): + here = False + async with aconn.pipeline(): + with pytest.raises(e.UndefinedTable): + async with aconn.transaction(): + await aconn.execute("select 1 from nosuchtable") + here = True + cur1 = await aconn.execute("select 1") + assert here + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +@pytest.mark.flakey("assert fails randomly in CI blocking release") +async def test_errors_raised_on_nested_transaction_exit(aconn): + here = False + async with aconn.pipeline(): + async with aconn.transaction(): + with pytest.raises(e.UndefinedTable): + async with aconn.transaction(): + await aconn.execute("select 1 from nosuchtable") + here = True + cur1 = await aconn.execute("select 1") + assert here + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +async def test_implicit_transaction(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE + await aconn.execute("select 'before'") + # Transaction is ACTIVE because previous command is not completed + # since we have not fetched its results. + assert aconn.pgconn.transaction_status == pq.TransactionStatus.ACTIVE + # Upon entering the nested pipeline through "with transaction():", a + # sync() is emitted to restore the transaction state to IDLE, as + # expected to emit a BEGIN. + async with aconn.transaction(): + await aconn.execute("select 'tx'") + cur = await aconn.execute("select 'after'") + assert await cur.fetchone() == ("after",) + + +@pytest.mark.crdb_skip("deferrable") +async def test_error_on_commit(aconn): + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + await aconn.commit() + + async with aconn.pipeline(): + await aconn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + await aconn.commit() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert (await cur1.fetchone()) == (1,) + assert (await cur2.fetchone()) == (2,) + + +async def test_fetch_no_result(aconn): + async with aconn.pipeline(): + cur = aconn.cursor() + with pytest.raises(e.ProgrammingError): + await cur.fetchone() + + +async def test_executemany(aconn): + await aconn.set_autocommit(True) + await aconn.execute("drop table if exists execmanypipeline") + await aconn.execute( + "create unlogged table execmanypipeline (" + " id serial primary key, num integer)" + ) + async with aconn.pipeline(), aconn.cursor() as cur: + await cur.executemany( + "insert into execmanypipeline(num) values (%s) returning num", + [(10,), (20,)], + returning=True, + ) + assert cur.rowcount == 2 + assert (await cur.fetchone()) == (10,) + assert cur.nextset() + assert (await cur.fetchone()) == (20,) + assert cur.nextset() is None + + +async def test_executemany_no_returning(aconn): + await aconn.set_autocommit(True) + await aconn.execute("drop table if exists execmanypipelinenoreturning") + await aconn.execute( + "create unlogged table execmanypipelinenoreturning (" + " id serial primary key, num integer)" + ) + async with aconn.pipeline(), aconn.cursor() as cur: + await cur.executemany( + "insert into execmanypipelinenoreturning(num) values (%s)", + [(10,), (20,)], + returning=False, + ) + with pytest.raises(e.ProgrammingError, match="no result available"): + await cur.fetchone() + assert cur.nextset() is None + with pytest.raises(e.ProgrammingError, match="no result available"): + await cur.fetchone() + assert cur.nextset() is None + + +@pytest.mark.crdb("skip", reason="temp tables") +async def test_executemany_trace(aconn, trace): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("create temp table trace (id int)") + t = trace.trace(aconn) + async with aconn.pipeline(): + await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + await cur.executemany("insert into trace (id) values (%s)", [(10,), (20,)]) + await aconn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] + assert len([i for i in items if i.type == "Sync"]) == 1 + + +@pytest.mark.crdb("skip", reason="temp tables") +async def test_executemany_trace_returning(aconn, trace): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("create temp table trace (id int)") + t = trace.trace(aconn) + async with aconn.pipeline(): + await cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + await cur.executemany( + "insert into trace (id) values (%s)", [(10,), (20,)], returning=True + ) + await aconn.close() + items = list(t) + assert items[-1].type == "Terminate" + del items[-1] + roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))] + assert roundtrips == ["F", "B"] * 3 + assert items[-2].direction == "F" # last 2 items are F B + assert len([i for i in items if i.type == "Sync"]) == 1 + + +async def test_prepared(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + c1 = await aconn.execute("select %s::int", [10], prepare=True) + c2 = await aconn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + + (r,) = await c1.fetchone() + assert r == 10 + + (r,) = await c2.fetchone() + assert r == 1 + + +async def test_auto_prepare(aconn): + aconn.prepared_threshold = 5 + async with aconn.pipeline(): + cursors = [ + await aconn.execute( + "select count(*) from pg_prepared_statements where name != ''" + ) + for i in range(10) + ] + + assert len(aconn._prepared._names) == 1 + + res = [(await c.fetchone())[0] for c in cursors] + assert res == [0] * 5 + [1] * 5 + + +async def test_transaction(aconn): + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + async with aconn.pipeline(): + async with aconn.transaction(): + cur = await aconn.execute("select 'tx'") + + (r,) = await cur.fetchone() + assert r == "tx" + + async with aconn.transaction(): + cur = await aconn.execute("select 'rb'") + raise psycopg.Rollback() + + (r,) = await cur.fetchone() + assert r == "rb" + + assert not notices + + +async def test_transaction_nested(aconn): + async with aconn.pipeline(): + async with aconn.transaction(): + outer = await aconn.execute("select 'outer'") + with pytest.raises(ZeroDivisionError): + async with aconn.transaction(): + inner = await aconn.execute("select 'inner'") + 1 / 0 + + (r,) = await outer.fetchone() + assert r == "outer" + (r,) = await inner.fetchone() + assert r == "inner" + + +async def test_transaction_nested_no_statement(aconn): + async with aconn.pipeline(): + async with aconn.transaction(): + async with aconn.transaction(): + cur = await aconn.execute("select 1") + + (r,) = await cur.fetchone() + assert r == 1 + + +async def test_outer_transaction(aconn): + async with aconn.transaction(): + await aconn.execute("drop table if exists outertx") + async with aconn.transaction(): + async with aconn.pipeline(): + await aconn.execute("create table outertx as (select 1)") + cur = await aconn.execute("select * from outertx") + (r,) = await cur.fetchone() + assert r == 1 + cur = await aconn.execute( + "select count(*) from pg_tables where tablename = 'outertx'" + ) + assert (await cur.fetchone())[0] == 1 + + +async def test_outer_transaction_error(aconn): + async with aconn.transaction(): + with pytest.raises((e.UndefinedColumn, e.OperationalError)): + async with aconn.pipeline(): + await aconn.execute("select error") + await aconn.execute("create table voila ()") + + +async def test_rollback_explicit(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.rollback() + await aconn.execute("select 1") + + +async def test_rollback_transaction(aconn): + await aconn.set_autocommit(True) + with pytest.raises(e.DivisionByZero): + async with aconn.pipeline(): + async with aconn.transaction(): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.execute("select 1") + + +async def test_message_0x33(aconn): + # https://github.com/psycopg/psycopg/issues/314 + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + + await aconn.set_autocommit(True) + async with aconn.pipeline(): + cur = await aconn.execute("select 'test'") + assert (await cur.fetchone()) == ("test",) + + assert not notices + + +async def test_transaction_state_implicit_begin(aconn, trace): + # Regression test to ensure that the transaction state is correct after + # the implicit BEGIN statement (in non-autocommit mode). + notices = [] + aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary)) + t = trace.trace(aconn) + async with aconn.pipeline(): + await (await aconn.execute("select 'x'")).fetchone() + await aconn.execute("select 'y'") + assert not notices + assert [ + e.content[0] for e in t if e.type == "Parse" and b"BEGIN" in e.content[0] + ] == [b' "" "BEGIN" 0'] + + +async def test_concurrency(aconn): + async with aconn.transaction(): + await aconn.execute("drop table if exists pipeline_concurrency") + await aconn.execute("drop table if exists accessed") + async with aconn.transaction(): + await aconn.execute( + "create unlogged table pipeline_concurrency (" + " id serial primary key," + " value integer" + ")" + ) + await aconn.execute("create unlogged table accessed as (select now() as value)") + + async def update(value): + cur = await aconn.execute( + "insert into pipeline_concurrency(value) values (%s) returning value", + (value,), + ) + await aconn.execute("update accessed set value = now()") + return cur + + await aconn.set_autocommit(True) + + (before,) = await (await aconn.execute("select value from accessed")).fetchone() + + values = range(1, 10) + async with aconn.pipeline(): + cursors = await asyncio.wait_for( + asyncio.gather(*[update(value) for value in values]), + timeout=len(values), + ) + + assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) + + (s,) = await ( + await aconn.execute("select sum(value) from pipeline_concurrency") + ).fetchone() + assert s == sum(values) + (after,) = await (await aconn.execute("select value from accessed")).fetchone() + assert after > before diff --git a/tests/test_prepared.py b/tests/test_prepared.py new file mode 100644 index 0000000..56c580a --- /dev/null +++ b/tests/test_prepared.py @@ -0,0 +1,277 @@ +""" +Prepared statements tests +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg.rows import namedtuple_row + + +@pytest.mark.parametrize("value", [None, 0, 3]) +def test_prepare_threshold_init(conn_cls, dsn, value): + with conn_cls.connect(dsn, prepare_threshold=value) as conn: + assert conn.prepare_threshold == value + + +def test_dont_prepare(conn): + cur = conn.cursor() + for i in range(10): + cur.execute("select %s::int", [i], prepare=False) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_do_prepare(conn): + cur = conn.cursor() + cur.execute("select %s::int", [10], prepare=True) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_auto_prepare(conn): + res = [] + for i in range(10): + conn.execute("select %s::int", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +def test_dont_prepare_conn(conn): + for i in range(10): + conn.execute("select %s::int", [i], prepare=False) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_do_prepare_conn(conn): + conn.execute("select %s::int", [10], prepare=True) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_auto_prepare_conn(conn): + res = [] + for i in range(10): + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +def test_prepare_disable(conn): + conn.prepare_threshold = None + res = [] + for i in range(10): + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 10 + assert not conn._prepared._names + assert not conn._prepared._counts + + +def test_no_prepare_multi(conn): + res = [] + for i in range(10): + conn.execute("select 1; select 2") + stmts = get_prepared_statements(conn) + res.append(len(stmts)) + + assert res == [0] * 10 + + +def test_no_prepare_multi_with_drop(conn): + conn.execute("select 1", prepare=True) + + for i in range(10): + conn.execute("drop table if exists noprep; create table noprep()") + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +def test_no_prepare_error(conn): + conn.autocommit = True + for i in range(10): + with pytest.raises(conn.ProgrammingError): + conn.execute("select wat") + + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")), + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +def test_misc_statement(conn, query): + conn.execute("create table prepared_test (num int)", prepare=False) + conn.prepare_threshold = 0 + conn.execute(query) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + + +def test_params_types(conn): + conn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + stmts = get_prepared_statements(conn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] + + +def test_evict_lru(conn): + conn.prepared_max = 5 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared._names) == 1 + assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert conn._prepared._counts[f"select {i}".encode(), ()] == 1 + + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" + + +def test_evict_lru_deallocate(conn): + conn.prepared_max = 5 + conn.prepare_threshold = 0 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared._names) == 5 + for j in [9, 8, 7, 6, "'a'"]: + name = conn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") + + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] + + +def test_different_types(conn): + conn.prepare_threshold = 0 + conn.execute("select %s", [None]) + conn.execute("select %s", [dt.date(2000, 1, 1)]) + conn.execute("select %s", [42]) + conn.execute("select %s", [41]) + conn.execute("select %s", [dt.date(2000, 1, 2)]) + + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] + + +def test_untyped_json(conn): + conn.prepare_threshold = 1 + conn.execute("create table testjson(data jsonb)") + + for i in range(2): + conn.execute("insert into testjson (data) values (%s)", ["{}"]) + + stmts = get_prepared_statements(conn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +def test_change_type_execute(conn): + conn.prepare_threshold = 0 + for i in range(3): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + conn.rollback() + + +def test_change_type_executemany(conn): + for i in range(3): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().executemany( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], + ) + conn.rollback() + + +@pytest.mark.crdb("skip", reason="can't re-create a type") +def test_change_type(conn): + conn.prepare_threshold = 0 + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + conn.execute("DROP TABLE preptable") + conn.execute("DROP TYPE prepenum") + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + + stmts = get_prepared_statements(conn) + assert len(stmts) == 3 + + +def test_change_type_savepoint(conn): + conn.prepare_threshold = 0 + with conn.transaction(): + for i in range(3): + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") + conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])") + conn.cursor().execute( + "INSERT INTO preptable (bar) VALUES (%(enum_col)s::prepenum[])", + {"enum_col": ["foo"]}, + ) + raise ZeroDivisionError() + + +def get_prepared_statements(conn): + cur = conn.cursor(row_factory=namedtuple_row) + cur.execute( + # CRDB has 'PREPARE name AS' in the statement. + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return cur.fetchall() diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py new file mode 100644 index 0000000..84d948f --- /dev/null +++ b/tests/test_prepared_async.py @@ -0,0 +1,207 @@ +""" +Prepared statements tests on async connections +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.parametrize("value", [None, 0, 3]) +async def test_prepare_threshold_init(aconn_cls, dsn, value): + async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn: + assert conn.prepare_threshold == value + + +async def test_dont_prepare(aconn): + cur = aconn.cursor() + for i in range(10): + await cur.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s::int", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_dont_prepare_conn(aconn): + for i in range(10): + await aconn.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare_conn(aconn): + await aconn.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare_conn(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_prepare_disable(aconn): + aconn.prepare_threshold = None + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + assert not aconn._prepared._names + assert not aconn._prepared._counts + + +async def test_no_prepare_multi(aconn): + res = [] + for i in range(10): + await aconn.execute("select 1; select 2") + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + + +async def test_no_prepare_error(aconn): + await aconn.set_autocommit(True) + for i in range(10): + with pytest.raises(aconn.ProgrammingError): + await aconn.execute("select wat") + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")), + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +async def test_misc_statement(aconn, query): + await aconn.execute("create table prepared_test (num int)", prepare=False) + aconn.prepare_threshold = 0 + await aconn.execute(query) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_params_types(aconn): + await aconn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + stmts = await get_prepared_statements(aconn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] + + +async def test_evict_lru(aconn): + aconn.prepared_max = 5 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 1 + assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1 + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" + + +async def test_evict_lru_deallocate(aconn): + aconn.prepared_max = 5 + aconn.prepare_threshold = 0 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 5 + for j in [9, 8, 7, 6, "'a'"]: + name = aconn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] + + +async def test_different_types(aconn): + aconn.prepare_threshold = 0 + await aconn.execute("select %s", [None]) + await aconn.execute("select %s", [dt.date(2000, 1, 1)]) + await aconn.execute("select %s", [42]) + await aconn.execute("select %s", [41]) + await aconn.execute("select %s", [dt.date(2000, 1, 2)]) + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] + + +async def test_untyped_json(aconn): + aconn.prepare_threshold = 1 + await aconn.execute("create table testjson(data jsonb)") + for i in range(2): + await aconn.execute("insert into testjson (data) values (%s)", ["{}"]) + + stmts = await get_prepared_statements(aconn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +async def get_prepared_statements(aconn): + cur = aconn.cursor(row_factory=namedtuple_row) + await cur.execute( + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return await cur.fetchall() diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py new file mode 100644 index 0000000..82a5d73 --- /dev/null +++ b/tests/test_psycopg_dbapi20.py @@ -0,0 +1,164 @@ +import pytest +import datetime as dt +from typing import Any, Dict + +import psycopg +from psycopg.conninfo import conninfo_to_dict + +from . import dbapi20 +from . import dbapi20_tpc + + +@pytest.fixture(scope="class") +def with_dsn(request, session_dsn): + request.cls.connect_args = (session_dsn,) + + +@pytest.mark.usefixtures("with_dsn") +class PsycopgTests(dbapi20.DatabaseAPI20Test): + driver = psycopg + # connect_args = () # set by the fixture + connect_kw_args: Dict[str, Any] = {} + + def test_nextset(self): + # tested elsewhere + pass + + def test_setoutputsize(self): + # no-op + pass + + +@pytest.mark.usefixtures("tpc") +@pytest.mark.usefixtures("with_dsn") +class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests): + driver = psycopg + connect_args = () # set by the fixture + + def connect(self): + return psycopg.connect(*self.connect_args) + + +# Shut up warnings +PsycopgTests.failUnless = PsycopgTests.assertTrue +PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual + + +@pytest.mark.parametrize( + "typename, singleton", + [ + ("bytea", "BINARY"), + ("date", "DATETIME"), + ("timestamp without time zone", "DATETIME"), + ("timestamp with time zone", "DATETIME"), + ("time without time zone", "DATETIME"), + ("time with time zone", "DATETIME"), + ("interval", "DATETIME"), + ("integer", "NUMBER"), + ("smallint", "NUMBER"), + ("bigint", "NUMBER"), + ("real", "NUMBER"), + ("double precision", "NUMBER"), + ("numeric", "NUMBER"), + ("decimal", "NUMBER"), + ("oid", "ROWID"), + ("varchar", "STRING"), + ("char", "STRING"), + ("text", "STRING"), + ], +) +def test_singletons(conn, typename, singleton): + singleton = getattr(psycopg, singleton) + cur = conn.cursor() + cur.execute(f"select null::{typename}") + oid = cur.description[0].type_code + assert singleton == oid + assert oid == singleton + assert singleton != oid + 10000 + assert oid + 10000 != singleton + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01T00:00:00.000000+0000"), + (1273173119.99992, "2010-05-06T14:11:59.999920-0500"), + ], +) +def test_timestamp_from_ticks(ticks, want): + s = psycopg.TimestampFromTicks(ticks) + want = dt.datetime.strptime(want, "%Y-%m-%dT%H:%M:%S.%f%z") + assert s == want + + +@pytest.mark.parametrize( + "ticks, want", + [ + (0, "1970-01-01"), + # Returned date is local + (1273173119.99992, ["2010-05-06", "2010-05-07"]), + ], +) +def test_date_from_ticks(ticks, want): + s = psycopg.DateFromTicks(ticks) + if isinstance(want, str): + want = [want] + want = [dt.datetime.strptime(w, "%Y-%m-%d").date() for w in want] + assert s in want + + +@pytest.mark.parametrize( + "ticks, want", + [(0, "00:00:00.000000"), (1273173119.99992, "00:11:59.999920")], +) +def test_time_from_ticks(ticks, want): + s = psycopg.TimeFromTicks(ticks) + want = dt.datetime.strptime(want, "%H:%M:%S.%f").time() + assert s.replace(hour=0) == want + + +@pytest.mark.parametrize( + "args, kwargs, want", + [ + ((), {}, ""), + (("",), {}, ""), + (("host=foo user=bar",), {}, "host=foo user=bar"), + (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + ( + ("host=foo port=5432",), + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + (("host=foo",), {"user": None}, "host=foo"), + ], +) +def test_connect_args(monkeypatch, pgconn, args, kwargs, want): + the_conninfo: str + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + conn = psycopg.connect(*args, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + conn.close() + + +@pytest.mark.parametrize( + "args, kwargs, exctype", + [ + (("host=foo", "host=bar"), {}, TypeError), + (("", ""), {}, TypeError), + ((), {"nosuchparam": 42}, psycopg.ProgrammingError), + ], +) +def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + with pytest.raises(exctype): + psycopg.connect(*args, **kwargs) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..7263a80 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,162 @@ +import pytest + +import psycopg +from psycopg import pq +from psycopg.adapt import Transformer, PyFormat +from psycopg._queries import PostgresQuery, _split_query + + +@pytest.mark.parametrize( + "input, want", + [ + (b"", [(b"", 0, PyFormat.AUTO)]), + (b"foo bar", [(b"foo bar", 0, PyFormat.AUTO)]), + (b"foo %% bar", [(b"foo % bar", 0, PyFormat.AUTO)]), + (b"%s", [(b"", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]), + (b"%s foo", [(b"", 0, PyFormat.AUTO), (b" foo", 0, PyFormat.AUTO)]), + (b"%b foo", [(b"", 0, PyFormat.BINARY), (b" foo", 0, PyFormat.AUTO)]), + (b"foo %s", [(b"foo ", 0, PyFormat.AUTO), (b"", 0, PyFormat.AUTO)]), + ( + b"foo %%%s bar", + [(b"foo %", 0, PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)], + ), + ( + b"foo %(name)s bar", + [(b"foo ", "name", PyFormat.AUTO), (b" bar", 0, PyFormat.AUTO)], + ), + ( + b"foo %(name)s %(name)b bar", + [ + (b"foo ", "name", PyFormat.AUTO), + (b" ", "name", PyFormat.BINARY), + (b" bar", 0, PyFormat.AUTO), + ], + ), + ( + b"foo %s%b bar %s baz", + [ + (b"foo ", 0, PyFormat.AUTO), + (b"", 1, PyFormat.BINARY), + (b" bar ", 2, PyFormat.AUTO), + (b" baz", 0, PyFormat.AUTO), + ], + ), + ], +) +def test_split_query(input, want): + assert _split_query(input) == want + + +@pytest.mark.parametrize( + "input", + [ + b"foo %d bar", + b"foo % bar", + b"foo %%% bar", + b"foo %(foo)d bar", + b"foo %(foo)s bar %s baz", + b"foo %(foo) bar", + b"foo %(foo bar", + b"3%2", + ], +) +def test_split_query_bad(input): + with pytest.raises(psycopg.ProgrammingError): + _split_query(input) + + +@pytest.mark.parametrize( + "query, params, want, wformats, wparams", + [ + (b"", None, b"", None, None), + (b"", [], b"", [], []), + (b"%%", [], b"%", [], []), + (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]), + ( + b"%t %% %t", + (1, 2), + b"$1 % $2", + [pq.Format.TEXT, pq.Format.TEXT], + [b"1", b"2"], + ), + ( + b"%t %% %t", + ("a", 2), + b"$1 % $2", + [pq.Format.TEXT, pq.Format.TEXT], + [b"a", b"2"], + ), + ], +) +def test_pg_query_seq(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams + + +@pytest.mark.parametrize( + "query, params, want, wformats, wparams", + [ + (b"", {}, b"", [], []), + (b"hello %%", {"a": 1}, b"hello %", [], []), + ( + b"select %(hello)t", + {"hello": 1, "world": 2}, + b"select $1", + [pq.Format.TEXT], + [b"1"], + ), + ( + b"select %(hi)s %(there)s %(hi)s", + {"hi": 0, "there": "a"}, + b"select $1 $2 $1", + [pq.Format.BINARY, pq.Format.TEXT], + [b"\x00" * 2, b"a"], + ), + ], +) +def test_pg_query_map(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams + + +@pytest.mark.parametrize( + "query, params", + [ + (b"select %s", {"a": 1}), + (b"select %(name)s", [1]), + (b"select %s", "a"), + (b"select %s", 1), + (b"select %s", b"a"), + (b"select %s", set()), + ], +) +def test_pq_query_badtype(query, params): + pq = PostgresQuery(Transformer()) + with pytest.raises(TypeError): + pq.convert(query, params) + + +@pytest.mark.parametrize( + "query, params", + [ + (b"", [1]), + (b"%s", []), + (b"%%", [1]), + (b"$1", [1]), + (b"select %(", {"a": 1}), + (b"select %(a", {"a": 1}), + (b"select %(a)", {"a": 1}), + (b"select %s %(hi)s", [1]), + (b"select %(hi)s %(hi)b", {"hi": 1}), + ], +) +def test_pq_query_badprog(query, params): + pq = PostgresQuery(Transformer()) + with pytest.raises(psycopg.ProgrammingError): + pq.convert(query, params) diff --git a/tests/test_rows.py b/tests/test_rows.py new file mode 100644 index 0000000..5165b80 --- /dev/null +++ b/tests/test_rows.py @@ -0,0 +1,167 @@ +import pytest + +import psycopg +from psycopg import rows + +from .utils import eur + + +def test_tuple_row(conn): + conn.row_factory = rows.dict_row + assert conn.execute("select 1 as a").fetchone() == {"a": 1} + cur = conn.cursor(row_factory=rows.tuple_row) + row = cur.execute("select 1 as a").fetchone() + assert row == (1,) + assert type(row) is tuple + assert cur._make_row is tuple + + +def test_dict_row(conn): + cur = conn.cursor(row_factory=rows.dict_row) + cur.execute("select 'bob' as name, 3 as id") + assert cur.fetchall() == [{"name": "bob", "id": 3}] + + cur.execute("select 'a' as letter; select 1 as number") + assert cur.fetchall() == [{"letter": "a"}] + assert cur.nextset() + assert cur.fetchall() == [{"number": 1}] + assert not cur.nextset() + + +def test_namedtuple_row(conn): + rows._make_nt.cache_clear() + cur = conn.cursor(row_factory=rows.namedtuple_row) + cur.execute("select 'bob' as name, 3 as id") + (person1,) = cur.fetchall() + assert f"{person1.name} {person1.id}" == "bob 3" + + ci1 = rows._make_nt.cache_info() + assert ci1.hits == 0 and ci1.misses == 1 + + cur.execute("select 'alice' as name, 1 as id") + (person2,) = cur.fetchall() + assert type(person2) is type(person1) + + ci2 = rows._make_nt.cache_info() + assert ci2.hits == 1 and ci2.misses == 1 + + cur.execute("select 'foo', 1 as id") + (r0,) = cur.fetchall() + assert r0.f_column_ == "foo" + assert r0.id == 1 + + cur.execute("select 'a' as letter; select 1 as number") + (r1,) = cur.fetchall() + assert r1.letter == "a" + assert cur.nextset() + (r2,) = cur.fetchall() + assert r2.number == 1 + assert not cur.nextset() + assert type(r1) is not type(r2) + + cur.execute(f'select 1 as üåäö, 2 as _, 3 as "123", 4 as "a-b", 5 as "{eur}eur"') + (r3,) = cur.fetchall() + assert r3.üåäö == 1 + assert r3.f_ == 2 + assert r3.f123 == 3 + assert r3.a_b == 4 + assert r3.f_eur == 5 + + +def test_class_row(conn): + cur = conn.cursor(row_factory=rows.class_row(Person)) + cur.execute("select 'John' as first, 'Doe' as last") + (p,) = cur.fetchall() + assert isinstance(p, Person) + assert p.first == "John" + assert p.last == "Doe" + assert p.age is None + + for query in ( + "select 'John' as first", + "select 'John' as first, 'Doe' as last, 42 as wat", + ): + cur.execute(query) + with pytest.raises(TypeError): + cur.fetchone() + + +def test_args_row(conn): + cur = conn.cursor(row_factory=rows.args_row(argf)) + cur.execute("select 'John' as first, 'Doe' as last") + assert cur.fetchone() == "JohnDoe" + + +def test_kwargs_row(conn): + cur = conn.cursor(row_factory=rows.kwargs_row(kwargf)) + cur.execute("select 'John' as first, 'Doe' as last") + (p,) = cur.fetchall() + assert isinstance(p, Person) + assert p.first == "John" + assert p.last == "Doe" + assert p.age == 42 + + +@pytest.mark.parametrize( + "factory", + "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(), +) +def test_no_result(factory, conn): + cur = conn.cursor(row_factory=factory_from_name(factory)) + cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() + + +@pytest.mark.crdb_skip("no col query") +@pytest.mark.parametrize( + "factory", "tuple_row dict_row namedtuple_row args_row".split() +) +def test_no_column(factory, conn): + cur = conn.cursor(row_factory=factory_from_name(factory)) + cur.execute("select") + recs = cur.fetchall() + assert len(recs) == 1 + assert not recs[0] + + +@pytest.mark.crdb("skip") +def test_no_column_class_row(conn): + class Empty: + def __init__(self, x=10, y=20): + self.x = x + self.y = y + + cur = conn.cursor(row_factory=rows.class_row(Empty)) + cur.execute("select") + x = cur.fetchone() + assert isinstance(x, Empty) + assert x.x == 10 + assert x.y == 20 + + +def factory_from_name(name): + factory = getattr(rows, name) + if factory is rows.class_row: + factory = factory(Person) + if factory is rows.args_row: + factory = factory(argf) + if factory is rows.kwargs_row: + factory = factory(argf) + + return factory + + +class Person: + def __init__(self, first, last, age=None): + self.first = first + self.last = last + self.age = age + + +def argf(*args): + return "".join(map(str, args)) + + +def kwargf(**kwargs): + return Person(**kwargs, age=42) diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py new file mode 100644 index 0000000..f7b6c8e --- /dev/null +++ b/tests/test_server_cursor.py @@ -0,0 +1,525 @@ +import pytest + +import psycopg +from psycopg import rows, errors as e +from psycopg.pq import Format + +pytestmark = pytest.mark.crdb_skip("server-side cursor") + + +def test_init_row_factory(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is conn + assert cur.row_factory is conn.row_factory + + conn.row_factory = rows.dict_row + + with psycopg.ServerCursor(conn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +def test_init_params(conn): + with psycopg.ServerCursor(conn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + +@pytest.mark.crdb_skip("cursor invalid name") +def test_funny_name(conn): + cur = conn.cursor("1-2-3") + cur.execute("select generate_series(1, 3) as bar") + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + cur.close() + + +def test_repr(conn): + cur = conn.cursor("my-name") + assert "psycopg.ServerCursor" in str(cur) + assert "my-name" in repr(cur) + cur.close() + + +def test_connection(conn): + cur = conn.cursor("foo") + assert cur.connection is conn + cur.close() + + +def test_description(conn): + cur = conn.cursor("foo") + assert cur.name == "foo" + cur.execute("select generate_series(1, 10)::int4 as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 + cur.close() + + +def test_format(conn): + cur = conn.cursor("foo") + assert cur.format == Format.TEXT + cur.close() + + cur = conn.cursor("foo", binary=True) + assert cur.format == Format.BINARY + cur.close() + + +def test_query_params(conn): + with conn.cursor("foo") as cur: + assert cur._query is None + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur._query + assert b"declare" in cur._query.query.lower() + assert b"(1, $1)" in cur._query.query.lower() + assert cur._query.params == [bytes([0, 3])] # 3 as binary int2 + + +def test_binary_cursor_execute(conn): + cur = conn.cursor("foo", binary=True) + cur.execute("select generate_series(1, 2)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + cur.close() + + +def test_execute_binary(conn): + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 2)::int4", binary=True) + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + + cur.execute("select generate_series(1, 1)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + cur.close() + + +def test_binary_cursor_text_override(conn): + cur = conn.cursor("foo", binary=True) + cur.execute("select generate_series(1, 2)", binary=False) + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + assert cur.fetchone() == (2,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"2" + + cur.execute("select generate_series(1, 2)::int4") + assert cur.fetchone() == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + cur.close() + + +def test_close(conn, recwarn): + if conn.info.transaction_status == conn.TransactionStatus.INTRANS: + # connection dirty from previous failure + conn.execute("close foo") + recwarn.clear() + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 10) as bar") + cur.close() + assert cur.closed + + assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_idempotent(conn): + cur = conn.cursor("foo") + cur.execute("select 1") + cur.fetchall() + cur.close() + cur.close() + + +def test_close_broken_conn(conn): + cur = conn.cursor("foo") + conn.close() + cur.close() + assert cur.closed + + +def test_cursor_close_fetchone(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + for _ in range(5): + cur.fetchone() + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchone() + + +def test_cursor_close_fetchmany(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchmany(2)) == 2 + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchmany(2) + + +def test_cursor_close_fetchall(conn): + cur = conn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + cur.execute(query) + assert len(cur.fetchall()) == 10 + + cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + cur.fetchall() + + +def test_close_noop(conn, recwarn): + recwarn.clear() + cur = conn.cursor("foo") + cur.close() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_on_error(conn): + cur = conn.cursor("foo") + cur.execute("select 1") + with pytest.raises(e.ProgrammingError): + conn.execute("wat") + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + cur.close() + + +def test_pgresult(conn): + cur = conn.cursor() + cur.execute("select 1") + assert cur.pgresult + cur.close() + assert not cur.pgresult + + +def test_context(conn, recwarn): + recwarn.clear() + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, 10) as bar") + + assert cur.closed + assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +def test_close_no_clobber(conn): + with pytest.raises(e.DivisionByZero): + with conn.cursor("foo") as cur: + cur.execute("select 1 / %s", (0,)) + cur.fetchall() + + +def test_warn_close(conn, recwarn): + recwarn.clear() + cur = conn.cursor("foo") + cur.execute("select generate_series(1, 10) as bar") + del cur + assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +def test_execute_reuse(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as foo", (3,)) + assert cur.fetchone() == (1,) + + cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world")) + assert cur.fetchone() == ("hello", "world") + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["text"].oid + assert cur.description[1].name == "baz" + + +@pytest.mark.parametrize( + "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"] +) +def test_execute_error(conn, stmt): + cur = conn.cursor("foo") + with pytest.raises(e.ProgrammingError): + cur.execute(stmt) + cur.close() + + +def test_executemany(conn): + cur = conn.cursor("foo") + with pytest.raises(e.NotSupportedError): + cur.executemany("select %s", [(1,), (2,)]) + cur.close() + + +def test_fetchone(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (2,)) + assert cur.fetchone() == (1,) + assert cur.fetchone() == (2,) + assert cur.fetchone() is None + + +def test_fetchmany(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (5,)) + assert cur.fetchmany(3) == [(1,), (2,), (3,)] + assert cur.fetchone() == (4,) + assert cur.fetchmany(3) == [(5,)] + assert cur.fetchmany(3) == [] + + +def test_fetchall(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchall() == [(1,), (2,), (3,)] + assert cur.fetchall() == [] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + assert cur.fetchall() == [(2,), (3,)] + assert cur.fetchall() == [] + + +def test_nextset(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert not cur.nextset() + + +def test_no_result(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar where false", (3,)) + assert len(cur.description) == 1 + assert cur.fetchall() == [] + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_standard_row_factory(conn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + with conn.cursor("foo", row_factory=row_factory) as cur: + cur.execute("select generate_series(1, 5) as bar") + assert getter(cur.fetchone()) == 1 + assert list(map(getter, cur.fetchmany(2))) == [2, 3] + assert list(map(getter, cur.fetchall())) == [4, 5] + + +@pytest.mark.crdb_skip("scroll cursor") +def test_row_factory(conn): + n = 0 + + def my_row_factory(cur): + nonlocal n + n += 1 + return lambda values: [n] + [-v for v in values] + + cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True) + cur.execute("select generate_series(1, 3) as x") + recs = cur.fetchall() + cur.scroll(0, "absolute") + while True: + rec = cur.fetchone() + if not rec: + break + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 + + cur.scroll(0, "absolute") + cur.row_factory = rows.dict_row + assert cur.fetchone() == {"x": 1} + cur.close() + + +def test_rownumber(conn): + cur = conn.cursor("foo") + assert cur.rownumber is None + + cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + cur.fetchone() + assert cur.rownumber == 1 + cur.fetchone() + assert cur.rownumber == 2 + cur.fetchmany(10) + assert cur.rownumber == 12 + cur.fetchall() + assert cur.rownumber == 42 + cur.close() + + +def test_iter(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = list(cur) + assert recs == [(1,), (2,), (3,)] + + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur.fetchone() == (1,) + recs = list(cur) + assert recs == [(2,), (3,)] + + +def test_iter_rownumber(conn): + with conn.cursor("foo") as cur: + cur.execute("select generate_series(1, %s) as bar", (3,)) + for row in cur: + assert cur.rownumber == row[0] + + +def test_itersize(conn, commands): + with conn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + cur.execute("select generate_series(1, %s) as bar", (3,)) + commands.popall() # flush begin and other noise + + list(cur) + cmds = commands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert "fetch forward 2" in cmd.lower() + + +def test_cant_scroll_by_default(conn): + cur = conn.cursor("tmp") + assert cur.scrollable is None + with pytest.raises(e.ProgrammingError): + cur.scroll(0) + cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +def test_scroll(conn): + cur = conn.cursor("tmp", scrollable=True) + cur.execute("select generate_series(0,9)") + cur.scroll(2) + assert cur.fetchone() == (2,) + cur.scroll(2) + assert cur.fetchone() == (5,) + cur.scroll(2, mode="relative") + assert cur.fetchone() == (8,) + cur.scroll(9, mode="absolute") + assert cur.fetchone() == (9,) + + with pytest.raises(ValueError): + cur.scroll(9, mode="wat") + cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +def test_scrollable(conn): + curs = conn.cursor("foo", scrollable=True) + assert curs.scrollable is True + curs.execute("select generate_series(0, 5)") + curs.scroll(5) + for i in range(4, -1, -1): + curs.scroll(-1) + assert i == curs.fetchone()[0] + curs.scroll(-1) + curs.close() + + +def test_non_scrollable(conn): + curs = conn.cursor("foo", scrollable=False) + assert curs.scrollable is False + curs.execute("select generate_series(0, 5)") + curs.scroll(5) + with pytest.raises(e.OperationalError): + curs.scroll(-1) + curs.close() + + +@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}]) +def test_no_hold(conn, kwargs): + with conn.cursor("foo", **kwargs) as curs: + assert curs.withhold is False + curs.execute("select generate_series(0, 2)") + assert curs.fetchone() == (0,) + conn.commit() + with pytest.raises(e.InvalidCursorName): + curs.fetchone() + + +@pytest.mark.crdb_skip("cursor with hold") +def test_hold(conn): + with conn.cursor("foo", withhold=True) as curs: + assert curs.withhold is True + curs.execute("select generate_series(0, 5)") + assert curs.fetchone() == (0,) + conn.commit() + assert curs.fetchone() == (1,) + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +def test_steal_cursor(conn, row_factory): + cur1 = conn.cursor() + cur1.execute("declare test cursor for select generate_series(1, 6) as s") + + cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory)) + # can call fetch without execute + rec = cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 + assert cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert cur2.fetchall() == [(5,), (6,)] + cur2.close() + + +def test_stolen_cursor_close(conn): + cur1 = conn.cursor() + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() + + cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = conn.cursor("test") + cur2.close() diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py new file mode 100644 index 0000000..21b4345 --- /dev/null +++ b/tests/test_server_cursor_async.py @@ -0,0 +1,543 @@ +import pytest + +import psycopg +from psycopg import rows, errors as e +from psycopg.pq import Format + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("server-side cursor"), +] + + +async def test_init_row_factory(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is aconn + assert cur.row_factory is aconn.row_factory + + aconn.row_factory = rows.dict_row + + async with psycopg.AsyncServerCursor(aconn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + async with psycopg.AsyncServerCursor( + aconn, "baz", row_factory=rows.namedtuple_row + ) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +async def test_init_params(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + async with psycopg.AsyncServerCursor( + aconn, "bar", withhold=True, scrollable=False + ) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + +@pytest.mark.crdb_skip("cursor invalid name") +async def test_funny_name(aconn): + cur = aconn.cursor("1-2-3") + await cur.execute("select generate_series(1, 3) as bar") + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + await cur.close() + + +async def test_repr(aconn): + cur = aconn.cursor("my-name") + assert "psycopg.AsyncServerCursor" in str(cur) + assert "my-name" in repr(cur) + await cur.close() + + +async def test_connection(aconn): + cur = aconn.cursor("foo") + assert cur.connection is aconn + await cur.close() + + +async def test_description(aconn): + cur = aconn.cursor("foo") + assert cur.name == "foo" + await cur.execute("select generate_series(1, 10)::int4 as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 + await cur.close() + + +async def test_format(aconn): + cur = aconn.cursor("foo") + assert cur.format == Format.TEXT + await cur.close() + + cur = aconn.cursor("foo", binary=True) + assert cur.format == Format.BINARY + await cur.close() + + +async def test_query_params(aconn): + async with aconn.cursor("foo") as cur: + assert cur._query is None + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur._query is not None + assert b"declare" in cur._query.query.lower() + assert b"(1, $1)" in cur._query.query.lower() + assert cur._query.params == [bytes([0, 3])] # 3 as binary int2 + + +async def test_binary_cursor_execute(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + await cur.close() + + +async def test_execute_binary(aconn): + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 2)::int4", binary=True) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + + await cur.execute("select generate_series(1, 1)") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + await cur.close() + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)", binary=False) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"2" + + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + await cur.close() + + +async def test_close(aconn, recwarn): + if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS: + # connection dirty from previous failure + await aconn.execute("close foo") + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + await cur.close() + assert cur.closed + + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_idempotent(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + await cur.fetchall() + await cur.close() + await cur.close() + + +async def test_close_broken_conn(aconn): + cur = aconn.cursor("foo") + await aconn.close() + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchall() + + +async def test_close_noop(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.close() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_on_error(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + with pytest.raises(e.ProgrammingError): + await aconn.execute("wat") + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + await cur.close() + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_context(aconn, recwarn): + recwarn.clear() + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, 10) as bar") + + assert cur.closed + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_no_clobber(aconn): + with pytest.raises(e.DivisionByZero): + async with aconn.cursor("foo") as cur: + await cur.execute("select 1 / %s", (0,)) + await cur.fetchall() + + +async def test_warn_close(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + del cur + assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +async def test_execute_reuse(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as foo", (3,)) + assert await cur.fetchone() == (1,) + + await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world")) + assert await cur.fetchone() == ("hello", "world") + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["text"].oid + assert cur.description[1].name == "baz" + + +@pytest.mark.parametrize( + "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"] +) +async def test_execute_error(aconn, stmt): + cur = aconn.cursor("foo") + with pytest.raises(e.ProgrammingError): + await cur.execute(stmt) + await cur.close() + + +async def test_executemany(aconn): + cur = aconn.cursor("foo") + with pytest.raises(e.NotSupportedError): + await cur.executemany("select %s", [(1,), (2,)]) + await cur.close() + + +async def test_fetchone(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (2,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchone() == (2,) + assert await cur.fetchone() is None + + +async def test_fetchmany(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (5,)) + assert await cur.fetchmany(3) == [(1,), (2,), (3,)] + assert await cur.fetchone() == (4,) + assert await cur.fetchmany(3) == [(5,)] + assert await cur.fetchmany(3) == [] + + +async def test_fetchall(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert await cur.fetchall() == [] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchall() == [(2,), (3,)] + assert await cur.fetchall() == [] + + +async def test_nextset(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert not cur.nextset() + + +async def test_no_result(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar where false", (3,)) + assert len(cur.description) == 1 + assert (await cur.fetchall()) == [] + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_standard_row_factory(aconn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + async with aconn.cursor("foo", row_factory=row_factory) as cur: + await cur.execute("select generate_series(1, 5) as bar") + assert getter(await cur.fetchone()) == 1 + assert list(map(getter, await cur.fetchmany(2))) == [2, 3] + assert list(map(getter, await cur.fetchall())) == [4, 5] + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_row_factory(aconn): + n = 0 + + def my_row_factory(cur): + nonlocal n + n += 1 + return lambda values: [n] + [-v for v in values] + + cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True) + await cur.execute("select generate_series(1, 3) as x") + recs = await cur.fetchall() + await cur.scroll(0, "absolute") + while True: + rec = await cur.fetchone() + if not rec: + break + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 + + await cur.scroll(0, "absolute") + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"x": 1} + await cur.close() + + +async def test_rownumber(aconn): + cur = aconn.cursor("foo") + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + await cur.fetchall() + assert cur.rownumber == 42 + await cur.close() + + +async def test_iter(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(1,), (2,), (3,)] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(2,), (3,)] + + +async def test_iter_rownumber(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + async for row in cur: + assert cur.rownumber == row[0] + + +async def test_itersize(aconn, acommands): + async with aconn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + await cur.execute("select generate_series(1, %s) as bar", (3,)) + acommands.popall() # flush begin and other noise + + async for rec in cur: + pass + cmds = acommands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert "fetch forward 2" in cmd.lower() + + +async def test_cant_scroll_by_default(aconn): + cur = aconn.cursor("tmp") + assert cur.scrollable is None + with pytest.raises(e.ProgrammingError): + await cur.scroll(0) + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scroll(aconn): + cur = aconn.cursor("tmp", scrollable=True) + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + + with pytest.raises(ValueError): + await cur.scroll(9, mode="wat") + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=True) + assert curs.scrollable is True + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + for i in range(4, -1, -1): + await curs.scroll(-1) + assert i == (await curs.fetchone())[0] + await curs.scroll(-1) + await curs.close() + + +async def test_non_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=False) + assert curs.scrollable is False + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + with pytest.raises(e.OperationalError): + await curs.scroll(-1) + await curs.close() + + +@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}]) +async def test_no_hold(aconn, kwargs): + async with aconn.cursor("foo", **kwargs) as curs: + assert curs.withhold is False + await curs.execute("select generate_series(0, 2)") + assert await curs.fetchone() == (0,) + await aconn.commit() + with pytest.raises(e.InvalidCursorName): + await curs.fetchone() + + +@pytest.mark.crdb_skip("cursor with hold") +async def test_hold(aconn): + async with aconn.cursor("foo", withhold=True) as curs: + assert curs.withhold is True + await curs.execute("select generate_series(0, 5)") + assert await curs.fetchone() == (0,) + await aconn.commit() + assert await curs.fetchone() == (1,) + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +async def test_steal_cursor(aconn, row_factory): + cur1 = aconn.cursor() + await cur1.execute( + "declare test cursor without hold for select generate_series(1, 6) as s" + ) + + cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory)) + # can call fetch without execute + rec = await cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 + assert await cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert await cur2.fetchall() == [(5,), (6,)] + await cur2.close() + + +async def test_stolen_cursor_close(aconn): + cur1 = aconn.cursor() + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() + + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 0000000..42b6c63 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,604 @@ +# test_sql.py - tests for the psycopg2.sql module + +# Copyright (C) 2020 The Psycopg Team + +import re +import datetime as dt + +import pytest + +from psycopg import pq, sql, ProgrammingError +from psycopg.adapt import PyFormat +from psycopg._encodings import py2pgenc +from psycopg.types import TypeInfo +from psycopg.types.string import StrDumper + +from .utils import eur +from .fix_crdb import crdb_encoding, crdb_scs_off + + +@pytest.mark.parametrize( + "obj, quoted", + [ + ("foo\\bar", " E'foo\\\\bar'"), + ("hello", "'hello'"), + (42, "42"), + (True, "true"), + (None, "NULL"), + ], +) +def test_quote(obj, quoted): + assert sql.quote(obj) == quoted + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_quote_roundtrip(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + + for i in range(1, 256): + want = chr(i) + quoted = sql.quote(want) + got = conn.execute(f"select {quoted}::text").fetchone()[0] + assert want == got + + # No "nonstandard use of \\ in a string literal" warning + assert not messages, f"error with {want!r}" + + +@pytest.mark.parametrize("dummy", [crdb_scs_off("off")]) +def test_quote_stable_despite_deranged_libpq(conn, dummy): + # Verify the libpq behaviour of PQescapeString using the last setting seen. + # Check that we are not affected by it. + good_str = " E'\\\\'" + good_bytes = " E'\\\\000'::bytea" + conn.execute("set standard_conforming_strings to on") + assert pq.Escaping().escape_string(b"\\") == b"\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\000" + assert sql.quote(b"\x00") == good_bytes + + conn.execute("set standard_conforming_strings to off") + assert pq.Escaping().escape_string(b"\\") == b"\\\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000" + assert sql.quote(b"\x00") == good_bytes + + # Verify that the good values are actually good + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute("set escape_string_warning to on") + for scs in ("on", "off"): + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.execute(f"select {good_str}, {good_bytes}::bytea") + assert cur.fetchone() == ("\\", b"\x00") + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +class TestSqlFormat: + def test_pos(self, conn): + s = sql.SQL("select {} from {}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_pos_spec(self, conn): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier("table"), sql.Identifier("field") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_dict(self, conn): + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier("field"), t=sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_compose_literal(self, conn): + s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) + s1 = s.as_string(conn) + assert s1 == "select '2016-12-31'::date;" + + def test_compose_empty(self, conn): + s = sql.SQL("select foo;").format() + s1 = s.as_string(conn) + assert s1 == "select foo;" + + def test_percent_escape(self, conn): + s = sql.SQL("42 % {0}").format(sql.Literal(7)) + s1 = s.as_string(conn) + assert s1 == "42 % 7" + + def test_braces_escape(self, conn): + s = sql.SQL("{{{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{7}" + s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{1,7}" + + def test_compose_badnargs(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format() + + def test_compose_badnargs_auto(self): + with pytest.raises(IndexError): + sql.SQL("select {};").format() + with pytest.raises(ValueError): + sql.SQL("select {} {1};").format(10, 20) + with pytest.raises(ValueError): + sql.SQL("select {0} {};").format(10, 20) + + def test_compose_bad_args_type(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format(a=10) + with pytest.raises(KeyError): + sql.SQL("select {x};").format(10) + + def test_no_modifiers(self): + with pytest.raises(ValueError): + sql.SQL("select {a!r};").format(a=10) + with pytest.raises(ValueError): + sql.SQL("select {a:<};").format(a=10) + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + s = sql.SQL("select {0};").format(sql.Literal(Foo())) + with pytest.raises(ProgrammingError): + s.as_string(conn) + + def test_auto_literal(self, conn): + s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1)) + assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date" + + def test_execute(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + (10, "a", "b", "c"), + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c")] + + def test_executemany(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + [(10, "a", "b", "c"), (20, "d", "e", "f")], + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")] + + @pytest.mark.crdb_skip("copy") + def test_copy(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + + with cur.copy( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ), + ) as copy: + copy.write_row((10, "a", "b", "c")) + copy.write_row((20, "d", "e", "f")) + + with cur.copy( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ) + ) as copy: + assert list(copy) == [b"c\n", b"f\n"] + + +class TestIdentifier: + def test_class(self): + assert issubclass(sql.Identifier, sql.Composable) + + def test_init(self): + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier) + with pytest.raises(TypeError): + sql.Identifier() + with pytest.raises(TypeError): + sql.Identifier(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.Identifier(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self): + obj = sql.Identifier("fo'o") + assert repr(obj) == 'Identifier("fo\'o")' + assert repr(obj) == str(obj) + + obj = sql.Identifier("fo'o", 'ba"r') + assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')" + assert repr(obj) == str(obj) + + def test_eq(self): + assert sql.Identifier("foo") == sql.Identifier("foo") + assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar") + assert sql.Identifier("foo") != sql.Identifier("bar") + assert sql.Identifier("foo") != "foo" + assert sql.Identifier("foo") != sql.SQL("foo") + + @pytest.mark.parametrize( + "args, want", + [ + (("foo",), '"foo"'), + (("foo", "bar"), '"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ], + ) + def test_as_string(self, conn, args, want): + assert sql.Identifier(*args).as_string(conn) == want + + @pytest.mark.parametrize( + "args, want, enc", + [ + crdb_encoding(("foo",), '"foo"', "ascii"), + crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), + crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), + ], + ) + def test_as_bytes(self, conn, args, want, enc): + want = want.encode(enc) + conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") + assert sql.Identifier(*args).as_bytes(conn) == want + + def test_join(self): + assert not hasattr(sql.Identifier("foo"), "join") + + +class TestLiteral: + def test_class(self): + assert issubclass(sql.Literal, sql.Composable) + + def test_init(self): + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal(b"foo"), sql.Literal) + assert isinstance(sql.Literal(42), sql.Literal) + assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal) + + def test_repr(self): + assert repr(sql.Literal("foo")) == "Literal('foo')" + assert str(sql.Literal("foo")) == "Literal('foo')" + + def test_as_string(self, conn): + assert sql.Literal(None).as_string(conn) == "NULL" + assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" + assert sql.Literal(42).as_string(conn) == "42" + assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" + + def test_as_bytes(self, conn): + assert sql.Literal(None).as_bytes(conn) == b"NULL" + assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" + assert sql.Literal(42).as_bytes(conn) == b"42" + assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode(encoding) + + def test_eq(self): + assert sql.Literal("foo") == sql.Literal("foo") + assert sql.Literal("foo") != sql.Literal("bar") + assert sql.Literal("foo") != "foo" + assert sql.Literal("foo") != sql.SQL("foo") + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + with pytest.raises(ProgrammingError): + sql.Literal(Foo()).as_string(conn) + + def test_array(self, conn): + assert ( + sql.Literal([dt.date(2000, 1, 1)]).as_string(conn) + == "'{2000-01-01}'::date[]" + ) + + def test_short_name_builtin(self, conn): + assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time" + assert ( + sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn) + == "'2000-01-01 00:00:00'::timestamp" + ) + assert ( + sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn) + == "'{\"2000-01-01 00:00:00\"}'::timestamp[]" + ) + + def test_text_literal(self, conn): + conn.adapters.register_dumper(str, StrDumper) + assert sql.Literal("foo").as_string(conn) == "'foo'" + + @pytest.mark.crdb_skip("composite") # create type, actually + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"]) + def test_invalid_name(self, conn, name): + conn.execute( + f""" + set client_encoding to utf8; + create type "{name}"; + create function invin(cstring) returns "{name}" + language internal immutable strict as 'textin'; + create function invout("{name}") returns cstring + language internal immutable strict as 'textout'; + create type "{name}" (input=invin, output=invout, like=text); + """ + ) + info = TypeInfo.fetch(conn, f'"{name}"') + + class InvDumper(StrDumper): + oid = info.oid + + def dump(self, obj): + rv = super().dump(obj) + return b"%s-inv" % rv + + info.register(conn) + conn.adapters.register_dumper(str, InvDumper) + + assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format("hello")) + assert cur.fetchone()[0] == "hello-inv" + + assert ( + sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]" + ) + cur = conn.execute(sql.SQL("select {}").format(["hello"])) + assert cur.fetchone()[0] == ["hello-inv"] + + +class TestSQL: + def test_class(self): + assert issubclass(sql.SQL, sql.Composable) + + def test_init(self): + assert isinstance(sql.SQL("foo"), sql.SQL) + assert isinstance(sql.SQL("foo"), sql.SQL) + with pytest.raises(TypeError): + sql.SQL(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.SQL(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self, conn): + assert repr(sql.SQL("foo")) == "SQL('foo')" + assert str(sql.SQL("foo")) == "SQL('foo')" + assert sql.SQL("foo").as_string(conn) == "foo" + + def test_eq(self): + assert sql.SQL("foo") == sql.SQL("foo") + assert sql.SQL("foo") != sql.SQL("bar") + assert sql.SQL("foo") != "foo" + assert sql.SQL("foo") != sql.Literal("foo") + + def test_sum(self, conn): + obj = sql.SQL("foo") + sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_sum_inplace(self, conn): + obj = sql.SQL("f") + sql.SQL("oo") + obj += sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_multiply(self, conn): + obj = sql.SQL("foo") * 3 + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foofoofoo" + + def test_join(self, conn): + obj = sql.SQL(", ").join( + [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join( + sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]) + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join([]) + assert obj == sql.Composed([]) + + def test_as_string(self, conn): + assert sql.SQL("foo").as_string(conn) == "foo" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes(self, conn, encoding): + if encoding: + conn.execute(f"set client_encoding to {encoding}") + + assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding) + + +class TestComposed: + def test_class(self): + assert issubclass(sql.Composed, sql.Composable) + + def test_repr(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])""" + assert str(obj) == repr(obj) + + def test_eq(self): + L = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + assert sql.Composed(L) == sql.Composed(list(L)) + assert sql.Composed(L) != L + assert sql.Composed(L) != sql.Composed(l2) + + def test_join(self, conn): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'foo', \"b'ar\"" + + def test_auto_literal(self, conn): + obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date" + + def test_sum(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_sum_inplace(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + it = iter(obj) + i = next(it) + assert i == sql.SQL("foo") + i = next(it) + assert i == sql.SQL("bar") + with pytest.raises(StopIteration): + next(it) + + def test_as_string(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_string(conn) == "foobar" + + def test_as_bytes(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_bytes(conn) == b"foobar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)]) + conn.execute(f"set client_encoding to {encoding}") + assert obj.as_bytes(conn) == ("foo" + eur).encode(encoding) + + +class TestPlaceholder: + def test_class(self): + assert issubclass(sql.Placeholder, sql.Composable) + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_format(self, conn, format): + ph = sql.Placeholder(format=format) + add = f"format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder({add})" + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_name_format(self, conn, format): + ph = sql.Placeholder("foo", format=format) + add = f", format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder('foo'{add})" + + def test_bad_name(self): + with pytest.raises(ValueError): + sql.Placeholder(")") + + def test_eq(self): + assert sql.Placeholder("foo") == sql.Placeholder("foo") + assert sql.Placeholder("foo") != sql.Placeholder("bar") + assert sql.Placeholder("foo") != "foo" + assert sql.Placeholder() == sql.Placeholder() + assert sql.Placeholder("foo") != sql.Placeholder() + assert sql.Placeholder("foo") != sql.Literal("foo") + + @pytest.mark.parametrize("format", PyFormat) + def test_as_string(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_string(conn) == f"%{format.value}" + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_string(conn) == f"%(foo){format.value}" + + @pytest.mark.parametrize("format", PyFormat) + def test_as_bytes(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii") + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii") + + +class TestValues: + def test_null(self, conn): + assert isinstance(sql.NULL, sql.SQL) + assert sql.NULL.as_string(conn) == "NULL" + + def test_default(self, conn): + assert isinstance(sql.DEFAULT, sql.SQL) + assert sql.DEFAULT.as_string(conn) == "DEFAULT" + + +def no_e(s): + """Drop an eventual E from E'' quotes""" + if isinstance(s, memoryview): + s = bytes(s) + + if isinstance(s, str): + return re.sub(r"\bE'", "'", s) + elif isinstance(s, bytes): + return re.sub(rb"\bE'", b"'", s) + else: + raise TypeError(f"not dealing with {type(s).__name__}: {s}") diff --git a/tests/test_tpc.py b/tests/test_tpc.py new file mode 100644 index 0000000..91a04e0 --- /dev/null +++ b/tests/test_tpc.py @@ -0,0 +1,325 @@ +import pytest + +import psycopg +from psycopg.pq import TransactionStatus + +pytestmark = pytest.mark.crdb_skip("2-phase commit") + + +def test_tpc_disabled(conn, pipeline): + val = int(conn.execute("show max_prepared_transactions").fetchone()[0]) + if val: + pytest.skip("prepared transactions enabled") + + conn.rollback() + conn.tpc_begin("x") + with pytest.raises(psycopg.NotSupportedError): + conn.tpc_prepare() + + +class TestTPC: + def test_tpc_commit(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_commit_one_phase(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_commit_recovered(self, conn_cls, conn, dsn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + conn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + with conn_cls.connect(dsn) as conn: + xid = conn.xid(1, "gtrid", "bqual") + conn.tpc_commit(xid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + def test_tpc_rollback(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_rollback() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_tpc_rollback_one_phase(self, conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_rollback() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_tpc_rollback_recovered(self, conn_cls, conn, dsn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + conn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + with conn_cls.connect(dsn) as conn: + xid = conn.xid(1, "gtrid", "bqual") + conn.tpc_rollback(xid) + assert conn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + def test_status_after_recover(self, conn, tpc): + assert conn.info.transaction_status == TransactionStatus.IDLE + conn.tpc_recover() + assert conn.info.transaction_status == TransactionStatus.IDLE + + cur = conn.cursor() + cur.execute("select 1") + assert conn.info.transaction_status == TransactionStatus.INTRANS + conn.tpc_recover() + assert conn.info.transaction_status == TransactionStatus.INTRANS + + def test_recovered_xids(self, conn, tpc): + # insert a few test xns + conn.autocommit = True + cur = conn.cursor() + cur.execute("begin; prepare transaction '1-foo'") + cur.execute("begin; prepare transaction '2-bar'") + + # read the values to return + cur.execute( + """ + select gid, prepared, owner, database from pg_prepared_xacts + where database = %s + """, + (conn.info.dbname,), + ) + okvals = cur.fetchall() + okvals.sort() + + xids = conn.tpc_recover() + xids = [xid for xid in xids if xid.database == conn.info.dbname] + xids.sort(key=lambda x: x.gtrid) + + # check the values returned + assert len(okvals) == len(xids) + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): + assert xid.gtrid == gid + assert xid.prepared == prepared + assert xid.owner == owner + assert xid.database == database + + def test_xid_encoding(self, conn, tpc): + xid = conn.xid(42, "gtrid", "bqual") + conn.tpc_begin(xid) + conn.tpc_prepare() + + cur = conn.cursor() + cur.execute( + "select gid from pg_prepared_xacts where database = %s", + (conn.info.dbname,), + ) + assert "42_Z3RyaWQ=_YnF1YWw=" == cur.fetchone()[0] + + @pytest.mark.parametrize( + "fid, gtrid, bqual", + [ + (0, "", ""), + (42, "gtrid", "bqual"), + (0x7FFFFFFF, "x" * 64, "y" * 64), + ], + ) + def test_xid_roundtrip(self, conn_cls, conn, dsn, tpc, fid, gtrid, bqual): + xid = conn.xid(fid, gtrid, bqual) + conn.tpc_begin(xid) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] + + assert len(xids) == 1 + xid = xids[0] + conn.tpc_rollback(xid) + + assert xid.format_id == fid + assert xid.gtrid == gtrid + assert xid.bqual == bqual + + @pytest.mark.parametrize( + "tid", + [ + "", + "hello, world!", + "x" * 199, # PostgreSQL's limit in transaction id length + ], + ) + def test_unparsed_roundtrip(self, conn_cls, conn, dsn, tpc, tid): + conn.tpc_begin(tid) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname] + + assert len(xids) == 1 + xid = xids[0] + conn.tpc_rollback(xid) + + assert xid.format_id is None + assert xid.gtrid == tid + assert xid.bqual is None + + def test_xid_unicode(self, conn_cls, conn, dsn, tpc): + x1 = conn.xid(10, "uni", "code") + conn.tpc_begin(x1) + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] + assert 10 == xid.format_id + assert "uni" == xid.gtrid + assert "code" == xid.bqual + + def test_xid_unicode_unparsed(self, conn_cls, conn, dsn, tpc): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check unicode is accepted as type. + conn.execute("set client_encoding to utf8") + conn.commit() + + conn.tpc_begin("transaction-id") + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "transaction-id" + assert xid.bqual is None + + def test_cancel_fails_prepared(self, conn, tpc): + conn.tpc_begin("cancel") + conn.tpc_prepare() + with pytest.raises(psycopg.ProgrammingError): + conn.cancel() + + def test_tpc_recover_non_dbapi_connection(self, conn_cls, conn, dsn, tpc): + conn.row_factory = psycopg.rows.dict_row + conn.tpc_begin("dict-connection") + conn.tpc_prepare() + conn.close() + + with conn_cls.connect(dsn) as conn: + xids = conn.tpc_recover() + xid = [x for x in xids if x.database == conn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "dict-connection" + assert xid.bqual is None + + +class TestXidObject: + def test_xid_construction(self): + x1 = psycopg.Xid(74, "foo", "bar") + 74 == x1.format_id + "foo" == x1.gtrid + "bar" == x1.bqual + + def test_xid_from_string(self): + x2 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=") + 42 == x2.format_id + "gtrid" == x2.gtrid + "bqual" == x2.bqual + + x3 = psycopg.Xid.from_string("99_xxx_yyy") + None is x3.format_id + "99_xxx_yyy" == x3.gtrid + None is x3.bqual + + def test_xid_to_string(self): + x1 = psycopg.Xid.from_string("42_Z3RyaWQ=_YnF1YWw=") + str(x1) == "42_Z3RyaWQ=_YnF1YWw=" + + x2 = psycopg.Xid.from_string("99_xxx_yyy") + str(x2) == "99_xxx_yyy" diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py new file mode 100644 index 0000000..a409a2e --- /dev/null +++ b/tests/test_tpc_async.py @@ -0,0 +1,310 @@ +import pytest + +import psycopg +from psycopg.pq import TransactionStatus + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("2-phase commit"), +] + + +async def test_tpc_disabled(aconn, apipeline): + cur = await aconn.execute("show max_prepared_transactions") + val = int((await cur.fetchone())[0]) + if val: + pytest.skip("prepared transactions enabled") + + await aconn.rollback() + await aconn.tpc_begin("x") + with pytest.raises(psycopg.NotSupportedError): + await aconn.tpc_prepare() + + +class TestTPC: + async def test_tpc_commit(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_commit() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_commit_recovered(self, aconn_cls, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + async with await aconn_cls.connect(dsn) as aconn: + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_commit(xid) + assert aconn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 + + async def test_tpc_rollback(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_rollback')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_one_phase(self, aconn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_rollback() + assert aconn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_tpc_rollback_recovered(self, aconn_cls, aconn, dsn, tpc): + xid = aconn.xid(1, "gtrid", "bqual") + assert aconn.info.transaction_status == TransactionStatus.IDLE + + await aconn.tpc_begin(xid) + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + cur = aconn.cursor() + await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + await aconn.tpc_prepare() + await aconn.close() + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + async with await aconn_cls.connect(dsn) as aconn: + xid = aconn.xid(1, "gtrid", "bqual") + await aconn.tpc_rollback(xid) + assert aconn.info.transaction_status == TransactionStatus.IDLE + + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + async def test_status_after_recover(self, aconn, tpc): + assert aconn.info.transaction_status == TransactionStatus.IDLE + await aconn.tpc_recover() + assert aconn.info.transaction_status == TransactionStatus.IDLE + + cur = aconn.cursor() + await cur.execute("select 1") + assert aconn.info.transaction_status == TransactionStatus.INTRANS + await aconn.tpc_recover() + assert aconn.info.transaction_status == TransactionStatus.INTRANS + + async def test_recovered_xids(self, aconn, tpc): + # insert a few test xns + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("begin; prepare transaction '1-foo'") + await cur.execute("begin; prepare transaction '2-bar'") + + # read the values to return + await cur.execute( + """ + select gid, prepared, owner, database from pg_prepared_xacts + where database = %s + """, + (aconn.info.dbname,), + ) + okvals = await cur.fetchall() + okvals.sort() + + xids = await aconn.tpc_recover() + xids = [xid for xid in xids if xid.database == aconn.info.dbname] + xids.sort(key=lambda x: x.gtrid) + + # check the values returned + assert len(okvals) == len(xids) + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): + assert xid.gtrid == gid + assert xid.prepared == prepared + assert xid.owner == owner + assert xid.database == database + + async def test_xid_encoding(self, aconn, tpc): + xid = aconn.xid(42, "gtrid", "bqual") + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + + cur = aconn.cursor() + await cur.execute( + "select gid from pg_prepared_xacts where database = %s", + (aconn.info.dbname,), + ) + assert "42_Z3RyaWQ=_YnF1YWw=" == (await cur.fetchone())[0] + + @pytest.mark.parametrize( + "fid, gtrid, bqual", + [ + (0, "", ""), + (42, "gtrid", "bqual"), + (0x7FFFFFFF, "x" * 64, "y" * 64), + ], + ) + async def test_xid_roundtrip(self, aconn_cls, aconn, dsn, tpc, fid, gtrid, bqual): + xid = aconn.xid(fid, gtrid, bqual) + await aconn.tpc_begin(xid) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + await aconn.tpc_rollback(xid) + + assert xid.format_id == fid + assert xid.gtrid == gtrid + assert xid.bqual == bqual + + @pytest.mark.parametrize( + "tid", + [ + "", + "hello, world!", + "x" * 199, # PostgreSQL's limit in transaction id length + ], + ) + async def test_unparsed_roundtrip(self, aconn_cls, aconn, dsn, tpc, tid): + await aconn.tpc_begin(tid) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ] + assert len(xids) == 1 + xid = xids[0] + await aconn.tpc_rollback(xid) + + assert xid.format_id is None + assert xid.gtrid == tid + assert xid.bqual is None + + async def test_xid_unicode(self, aconn_cls, aconn, dsn, tpc): + x1 = aconn.xid(10, "uni", "code") + await aconn.tpc_begin(x1) + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xid = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ][0] + + assert 10 == xid.format_id + assert "uni" == xid.gtrid + assert "code" == xid.bqual + + async def test_xid_unicode_unparsed(self, aconn_cls, aconn, dsn, tpc): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check unicode is accepted as type. + await aconn.execute("set client_encoding to utf8") + await aconn.commit() + + await aconn.tpc_begin("transaction-id") + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xid = [ + x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname + ][0] + + assert xid.format_id is None + assert xid.gtrid == "transaction-id" + assert xid.bqual is None + + async def test_cancel_fails_prepared(self, aconn, tpc): + await aconn.tpc_begin("cancel") + await aconn.tpc_prepare() + with pytest.raises(psycopg.ProgrammingError): + aconn.cancel() + + async def test_tpc_recover_non_dbapi_connection(self, aconn_cls, aconn, dsn, tpc): + aconn.row_factory = psycopg.rows.dict_row + await aconn.tpc_begin("dict-connection") + await aconn.tpc_prepare() + await aconn.close() + + async with await aconn_cls.connect(dsn) as aconn: + xids = await aconn.tpc_recover() + xid = [x for x in xids if x.database == aconn.info.dbname][0] + + assert xid.format_id is None + assert xid.gtrid == "dict-connection" + assert xid.bqual is None diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..9391e00 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,796 @@ +import sys +import logging +from threading import Thread, Event + +import pytest + +import psycopg +from psycopg import Rollback +from psycopg import errors as e + +# TODOCRDB: is this the expected behaviour? +crdb_skip_external_observer = pytest.mark.crdb( + "skip", reason="deadlock on observer connection" +) + + +@pytest.fixture +def conn(conn, pipeline): + return conn + + +@pytest.fixture(autouse=True) +def create_test_table(svcconn): + """Creates a table called 'test_table' for use in tests.""" + cur = svcconn.cursor() + cur.execute("drop table if exists test_table") + cur.execute("create table test_table (id text primary key)") + yield + cur.execute("drop table test_table") + + +def insert_row(conn, value): + sql = "INSERT INTO test_table VALUES (%s)" + if isinstance(conn, psycopg.Connection): + conn.cursor().execute(sql, (value,)) + else: + + async def f(): + cur = conn.cursor() + await cur.execute(sql, (value,)) + + return f() + + +def inserted(conn): + """Return the values inserted in the test table.""" + sql = "SELECT * FROM test_table" + if isinstance(conn, psycopg.Connection): + rows = conn.cursor().execute(sql).fetchall() + return set(v for (v,) in rows) + else: + + async def f(): + cur = conn.cursor() + await cur.execute(sql) + rows = await cur.fetchall() + return set(v for (v,) in rows) + + return f() + + +def in_transaction(conn): + if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE: + return False + elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS: + return True + else: + assert False, conn.pgconn.transaction_status + + +def get_exc_info(exc): + """Return the exc info for an exception or a success if exc is None""" + if not exc: + return (None,) * 3 + try: + raise exc + except exc: + return sys.exc_info() + + +class ExpectedException(Exception): + pass + + +def test_basic(conn, pipeline): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert not in_transaction(conn) + with conn.transaction(): + if pipeline: + pipeline.sync() + assert in_transaction(conn) + assert not in_transaction(conn) + + +def test_exposes_associated_connection(conn): + """Transaction exposes its connection as a read-only property.""" + with conn.transaction() as tx: + assert tx.connection is conn + with pytest.raises(AttributeError): + tx.connection = conn + + +def test_exposes_savepoint_name(conn): + """Transaction exposes its savepoint name as a read-only property.""" + with conn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +def test_cant_reenter(conn): + with conn.transaction() as tx: + pass + + with pytest.raises(TypeError): + with tx: + pass + + +def test_begins_on_enter(conn, pipeline): + """Transaction does not begin until __enter__() is called.""" + tx = conn.transaction() + assert not in_transaction(conn) + with tx: + if pipeline: + pipeline.sync() + assert in_transaction(conn) + assert not in_transaction(conn) + + +def test_commit_on_successful_exit(conn): + """Changes are committed on successful exit from the `with` block.""" + with conn.transaction(): + insert_row(conn, "foo") + + assert not in_transaction(conn) + assert inserted(conn) == {"foo"} + + +def test_rollback_on_exception_exit(conn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "foo") + raise ExpectedException("This discards the insert") + + assert not in_transaction(conn) + assert not inserted(conn) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +def test_context_inerror_rollback_no_clobber(conn_cls, conn, pipeline, dsn, caplog): + if pipeline: + # Only 'conn' is possibly in pipeline mode, but the transaction and + # checks are on 'conn2'. + pytest.skip("not applicable") + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + with conn_cls.connect(dsn) as conn2: + with conn2.transaction(): + conn2.execute("select 1") + conn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + conn = conn_cls.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + with conn.transaction(): + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + conn.close() + + +def test_interaction_dbapi_transaction(conn): + insert_row(conn, "foo") + + with conn.transaction(): + insert_row(conn, "bar") + raise Rollback + + with conn.transaction(): + insert_row(conn, "baz") + + assert in_transaction(conn) + conn.commit() + assert inserted(conn) == {"foo", "baz"} + + +def test_prohibits_use_of_commit_rollback_autocommit(conn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. + """ + conn.autocommit = False + conn.commit() + conn.rollback() + + with conn.transaction(): + with pytest.raises(e.ProgrammingError): + conn.autocommit = False + with pytest.raises(e.ProgrammingError): + conn.commit() + with pytest.raises(e.ProgrammingError): + conn.rollback() + + conn.autocommit = False + conn.commit() + conn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +def test_preserves_autocommit(conn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. + """ + conn.autocommit = autocommit + with conn.transaction(): + assert conn.autocommit is autocommit + assert conn.autocommit is autocommit + + +def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + conn.autocommit = False + assert not in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert not in_transaction(conn) + + # Changes committed + assert inserted(conn) == {"new"} + assert inserted(svcconn) == {"new"} + + +def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + conn.autocommit = False + assert not in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert not in_transaction(conn) + + # Changes discarded + assert not inserted(conn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_autocommit_off_and_tx_in_progress_successful_exit(conn, pipeline, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + conn.autocommit = False + insert_row(conn, "prior") + if pipeline: + pipeline.sync() + assert in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert in_transaction(conn) + assert inserted(conn) == {"prior", "new"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_autocommit_off_and_tx_in_progress_exception_exit(conn, pipeline, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + conn.autocommit = False + insert_row(conn, "prior") + if pipeline: + pipeline.sync() + assert in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert in_transaction(conn) + assert inserted(conn) == {"prior"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction(): + insert_row(conn, "inner") + insert_row(conn, "outer-after") + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} + + +def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + insert_row(conn, "outer-after") + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +def test_nested_three_levels_successful_exit(conn, svcconn): + """Exercise management of more than one savepoint.""" + with conn.transaction(): # BEGIN + insert_row(conn, "one") + with conn.transaction(): # SAVEPOINT s1 + insert_row(conn, "two") + with conn.transaction(): # SAVEPOINT s2 + insert_row(conn, "three") + assert not in_transaction(conn) + assert inserted(conn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} + + +def test_named_savepoint_escapes_savepoint_name(conn): + with conn.transaction("s-1"): + pass + with conn.transaction("s1; drop table students"): + pass + + +def test_named_savepoints_successful_exit(conn, commands): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + # Case 1 + # Using Transaction explicitly because conn.transaction() enters the contetx + assert not commands + with conn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + assert commands.popall() == ["COMMIT"] + + # Case 1 (with a transaction already started) + conn.cursor().execute("select 1") + assert commands.popall() == ["BEGIN"] + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_1"'] + assert tx.savepoint_name == "_pg3_1" + assert commands.popall() == ['RELEASE "_pg3_1"'] + conn.rollback() + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with conn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name provided) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with conn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + assert commands.popall() == ['RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + assert commands.popall() == ['RELEASE "_pg3_2"'] + assert commands.popall() == ["COMMIT"] + + +def test_named_savepoints_exception_exit(conn, commands): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. + """ + # Case 1 + with pytest.raises(ExpectedException): + with conn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with pytest.raises(ExpectedException): + with conn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 3 (with savepoint name provided) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + with conn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + raise ExpectedException + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + with conn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + raise ExpectedException + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] + assert commands.popall() == ["COMMIT"] + + +def test_named_savepoints_with_repeated_names_works(conn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. + """ + # Works correctly if no inner transactions are rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("sp"): + insert_row(conn, "tx1") + with conn.transaction("sp"): + insert_row(conn, "tx2") + with conn.transaction("sp"): + insert_row(conn, "tx3") + assert inserted(conn) == {"tx1", "tx2", "tx3"} + + # Works correctly if one level of inner transaction is rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1", force_rollback=True): + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1") as tx2: + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + raise Rollback(tx2) + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} + + +def test_force_rollback_successful_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + assert not inserted(conn) + assert not inserted(svcconn) + + +def test_force_rollback_exception_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + raise ExpectedException() + assert not inserted(conn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +def test_explicit_rollback_discards_changes(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + + def assert_no_rows(): + assert not inserted(conn) + assert not inserted(svcconn) + + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback + assert_no_rows() + + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback() + assert_no_rows() + + with conn.transaction() as tx: + insert_row(conn, "foo") + raise Rollback(tx) + assert_no_rows() + + +@crdb_skip_external_observer +def test_explicit_rollback_outer_tx_unaffected(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + with conn.transaction(): + insert_row(conn, "before") + with conn.transaction(): + insert_row(conn, "during") + raise Rollback + assert in_transaction(conn) + assert not inserted(svcconn) + insert_row(conn, "after") + assert inserted(conn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} + + +def test_explicit_rollback_of_outer_transaction(conn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. + """ + with conn.transaction() as outer_tx: + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert not inserted(conn) + + +@crdb_skip_external_observer +def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction() as tx_enclosing: + insert_row(conn, "enclosing") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(tx_enclosing) + insert_row(conn, "outer-after") + + assert inserted(conn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +def test_str(conn, pipeline): + with conn.transaction() as tx: + if pipeline: + assert "[INTRANS, pipeline=ON]" in str(tx) + else: + assert "[INTRANS]" in str(tx) + assert "(active)" in str(tx) + assert "'" not in str(tx) + with conn.transaction("wat") as tx2: + if pipeline: + assert "[INTRANS, pipeline=ON]" in str(tx2) + else: + assert "[INTRANS]" in str(tx2) + assert "'wat'" in str(tx2) + + if pipeline: + assert "[IDLE, pipeline=ON]" in str(tx) + else: + assert "[IDLE]" in str(tx) + assert "(terminated)" in str(tx) + + with pytest.raises(ZeroDivisionError): + with conn.transaction() as tx: + 1 / 0 + + assert "(terminated)" in str(tx) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_exit(conn, exit_error): + conn.autocommit = True + + t1 = conn.transaction() + t1.__enter__() + + t2 = conn.transaction() + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_implicit_begin(conn, exit_error): + conn.execute("select 1") + + t1 = conn.transaction() + t1.__enter__() + + t2 = conn.transaction() + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_exit_same_name(conn, exit_error): + conn.autocommit = True + + t1 = conn.transaction("save") + t1.__enter__() + t2 = conn.transaction("save") + t2.__enter__() + + with pytest.raises(e.ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("what", ["commit", "rollback", "error"]) +def test_concurrency(conn, what): + conn.autocommit = True + + evs = [Event() for i in range(3)] + + def worker(unlock, wait_on): + with pytest.raises(e.ProgrammingError) as ex: + with conn.transaction(): + unlock.set() + wait_on.wait() + conn.execute("select 1") + + if what == "error": + 1 / 0 + elif what == "rollback": + raise Rollback() + else: + assert what == "commit" + + if what == "error": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, ZeroDivisionError) + elif what == "rollback": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, Rollback) + else: + assert "transaction commit" in str(ex.value) + + # Start a first transaction in a thread + t1 = Thread(target=worker, kwargs={"unlock": evs[0], "wait_on": evs[1]}) + t1.start() + evs[0].wait() + + # Start a nested transaction in a thread + t2 = Thread(target=worker, kwargs={"unlock": evs[1], "wait_on": evs[2]}) + t2.start() + + # Terminate the first transaction before the second does + t1.join() + evs[2].set() + t2.join() diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py new file mode 100644 index 0000000..55e1c9c --- /dev/null +++ b/tests/test_transaction_async.py @@ -0,0 +1,743 @@ +import asyncio +import logging + +import pytest + +from psycopg import Rollback +from psycopg import errors as e +from psycopg._compat import create_task + +from .test_transaction import in_transaction, insert_row, inserted, get_exc_info +from .test_transaction import ExpectedException, crdb_skip_external_observer +from .test_transaction import create_test_table # noqa # autouse fixture + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def aconn(aconn, apipeline): + return aconn + + +async def test_basic(aconn, apipeline): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert not in_transaction(aconn) + async with aconn.transaction(): + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_exposes_associated_connection(aconn): + """Transaction exposes its connection as a read-only property.""" + async with aconn.transaction() as tx: + assert tx.connection is aconn + with pytest.raises(AttributeError): + tx.connection = aconn + + +async def test_exposes_savepoint_name(aconn): + """Transaction exposes its savepoint name as a read-only property.""" + async with aconn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +async def test_cant_reenter(aconn): + async with aconn.transaction() as tx: + pass + + with pytest.raises(TypeError): + async with tx: + pass + + +async def test_begins_on_enter(aconn, apipeline): + """Transaction does not begin until __enter__() is called.""" + tx = aconn.transaction() + assert not in_transaction(aconn) + async with tx: + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_commit_on_successful_exit(aconn): + """Changes are committed on successful exit from the `with` block.""" + async with aconn.transaction(): + await insert_row(aconn, "foo") + + assert not in_transaction(aconn) + assert await inserted(aconn) == {"foo"} + + +async def test_rollback_on_exception_exit(aconn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise ExpectedException("This discards the insert") + + assert not in_transaction(aconn) + assert not await inserted(aconn) + + +@pytest.mark.crdb_skip("pg_terminate_backend") +async def test_context_inerror_rollback_no_clobber( + aconn_cls, aconn, apipeline, dsn, caplog +): + if apipeline: + # Only 'aconn' is possibly in pipeline mode, but the transaction and + # checks are on 'conn2'. + pytest.skip("not applicable") + caplog.set_level(logging.WARNING, logger="psycopg") + + with pytest.raises(ZeroDivisionError): + async with await aconn_cls.connect(dsn) as conn2: + async with conn2.transaction(): + await conn2.execute("select 1") + await aconn.execute( + "select pg_terminate_backend(%s::int)", + [conn2.pgconn.backend_pid], + ) + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + + +@pytest.mark.crdb_skip("copy") +async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg") + + conn = await aconn_cls.connect(dsn) + try: + with pytest.raises(ZeroDivisionError): + async with conn.transaction(): + conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout") + status = conn.info.transaction_status + assert status == conn.TransactionStatus.ACTIVE + 1 / 0 + + assert len(caplog.records) == 1 + rec = caplog.records[0] + assert rec.levelno == logging.WARNING + assert "in rollback" in rec.message + finally: + await conn.close() + + +async def test_interaction_dbapi_transaction(aconn): + await insert_row(aconn, "foo") + + async with aconn.transaction(): + await insert_row(aconn, "bar") + raise Rollback + + async with aconn.transaction(): + await insert_row(aconn, "baz") + + assert in_transaction(aconn) + await aconn.commit() + assert await inserted(aconn) == {"foo", "baz"} + + +async def test_prohibits_use_of_commit_rollback_autocommit(aconn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. + """ + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + async with aconn.transaction(): + with pytest.raises(e.ProgrammingError): + await aconn.set_autocommit(False) + with pytest.raises(e.ProgrammingError): + await aconn.commit() + with pytest.raises(e.ProgrammingError): + await aconn.rollback() + + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +async def test_preserves_autocommit(aconn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. + """ + await aconn.set_autocommit(autocommit) + async with aconn.transaction(): + assert aconn.autocommit is autocommit + assert aconn.autocommit is autocommit + + +async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert not in_transaction(aconn) + + # Changes committed + assert await inserted(aconn) == {"new"} + assert inserted(svcconn) == {"new"} + + +async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert not in_transaction(aconn) + + # Changes discarded + assert not await inserted(aconn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_autocommit_off_and_tx_in_progress_successful_exit( + aconn, apipeline, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior", "new"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_autocommit_off_and_tx_in_progress_exception_exit( + aconn, apipeline, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + if apipeline: + await apipeline.sync() + assert in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction(): + await insert_row(aconn, "inner") + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} + + +async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. + """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +async def test_nested_three_levels_successful_exit(aconn, svcconn): + """Exercise management of more than one savepoint.""" + async with aconn.transaction(): # BEGIN + await insert_row(aconn, "one") + async with aconn.transaction(): # SAVEPOINT s1 + await insert_row(aconn, "two") + async with aconn.transaction(): # SAVEPOINT s2 + await insert_row(aconn, "three") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} + + +async def test_named_savepoint_escapes_savepoint_name(aconn): + async with aconn.transaction("s-1"): + pass + async with aconn.transaction("s1; drop table students"): + pass + + +async def test_named_savepoints_successful_exit(aconn, acommands): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + commands = acommands + + # Case 1 + # Using Transaction explicitly because conn.transaction() enters the contetx + async with aconn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + assert commands.popall() == ["COMMIT"] + + # Case 1 (with a transaction already started) + await aconn.cursor().execute("select 1") + assert commands.popall() == ["BEGIN"] + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_1"'] + assert tx.savepoint_name == "_pg3_1" + + assert commands.popall() == ['RELEASE "_pg3_1"'] + await aconn.rollback() + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + async with aconn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + async with aconn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + assert commands.popall() == ['RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + assert commands.popall() == ['RELEASE "_pg3_2"'] + assert commands.popall() == ["COMMIT"] + + +async def test_named_savepoints_exception_exit(aconn, acommands): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. + """ + commands = acommands + + # Case 1 + with pytest.raises(ExpectedException): + async with aconn.transaction() as tx: + assert commands.popall() == ["BEGIN"] + assert not tx.savepoint_name + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 2 + with pytest.raises(ExpectedException): + async with aconn.transaction(savepoint_name="foo") as tx: + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] + assert tx.savepoint_name == "foo" + raise ExpectedException + assert commands.popall() == ["ROLLBACK"] + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + async with aconn.transaction(savepoint_name="bar") as tx: + assert commands.popall() == ['SAVEPOINT "bar"'] + assert tx.savepoint_name == "bar" + raise ExpectedException + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] + assert commands.popall() == ["COMMIT"] + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.popall() == ["BEGIN"] + with pytest.raises(ExpectedException): + async with aconn.transaction() as tx: + assert commands.popall() == ['SAVEPOINT "_pg3_2"'] + assert tx.savepoint_name == "_pg3_2" + raise ExpectedException + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] + assert commands.popall() == ["COMMIT"] + + +async def test_named_savepoints_with_repeated_names_works(aconn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. + """ + # Works correctly if no inner transactions are rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("sp"): + await insert_row(aconn, "tx1") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx2") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1", "tx2", "tx3"} + + # Works correctly if one level of inner transaction is rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1", force_rollback=True): + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1") as tx2: + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + raise Rollback(tx2) + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + +async def test_force_rollback_successful_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_force_rollback_exception_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + raise ExpectedException() + assert not await inserted(aconn) + assert not inserted(svcconn) + + +@crdb_skip_external_observer +async def test_explicit_rollback_discards_changes(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + + async def assert_no_rows(): + assert not await inserted(aconn) + assert not inserted(svcconn) + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback + await assert_no_rows() + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback() + await assert_no_rows() + + async with aconn.transaction() as tx: + await insert_row(aconn, "foo") + raise Rollback(tx) + await assert_no_rows() + + +@crdb_skip_external_observer +async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + async with aconn.transaction(): + await insert_row(aconn, "before") + async with aconn.transaction(): + await insert_row(aconn, "during") + raise Rollback + assert in_transaction(aconn) + assert not inserted(svcconn) + await insert_row(aconn, "after") + assert await inserted(aconn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} + + +async def test_explicit_rollback_of_outer_transaction(aconn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. + """ + async with aconn.transaction() as outer_tx: + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert not await inserted(aconn) + + +@crdb_skip_external_observer +async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. + """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction() as tx_enclosing: + await insert_row(aconn, "enclosing") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(tx_enclosing) + await insert_row(aconn, "outer-after") + + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +async def test_str(aconn, apipeline): + async with aconn.transaction() as tx: + if apipeline: + assert "[INTRANS]" not in str(tx) + await apipeline.sync() + assert "[INTRANS, pipeline=ON]" in str(tx) + else: + assert "[INTRANS]" in str(tx) + assert "(active)" in str(tx) + assert "'" not in str(tx) + async with aconn.transaction("wat") as tx2: + if apipeline: + assert "[INTRANS, pipeline=ON]" in str(tx2) + else: + assert "[INTRANS]" in str(tx2) + assert "'wat'" in str(tx2) + + if apipeline: + assert "[IDLE, pipeline=ON]" in str(tx) + else: + assert "[IDLE]" in str(tx) + assert "(terminated)" in str(tx) + + with pytest.raises(ZeroDivisionError): + async with aconn.transaction() as tx: + 1 / 0 + + assert "(terminated)" in str(tx) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_exit(aconn, exit_error): + await aconn.set_autocommit(True) + + t1 = aconn.transaction() + await t1.__aenter__() + + t2 = aconn.transaction() + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_implicit_begin(aconn, exit_error): + await aconn.execute("select 1") + + t1 = aconn.transaction() + await t1.__aenter__() + + t2 = aconn.transaction() + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_exit_same_name(aconn, exit_error): + await aconn.set_autocommit(True) + + t1 = aconn.transaction("save") + await t1.__aenter__() + t2 = aconn.transaction("save") + await t2.__aenter__() + + with pytest.raises(e.ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(e.ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + +@pytest.mark.parametrize("what", ["commit", "rollback", "error"]) +async def test_concurrency(aconn, what): + await aconn.set_autocommit(True) + + evs = [asyncio.Event() for i in range(3)] + + async def worker(unlock, wait_on): + with pytest.raises(e.ProgrammingError) as ex: + async with aconn.transaction(): + unlock.set() + await wait_on.wait() + await aconn.execute("select 1") + + if what == "error": + 1 / 0 + elif what == "rollback": + raise Rollback() + else: + assert what == "commit" + + if what == "error": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, ZeroDivisionError) + elif what == "rollback": + assert "transaction rollback" in str(ex.value) + assert isinstance(ex.value.__context__, Rollback) + else: + assert "transaction commit" in str(ex.value) + + # Start a first transaction in a task + t1 = create_task(worker(unlock=evs[0], wait_on=evs[1])) + await evs[0].wait() + + # Start a nested transaction in a task + t2 = create_task(worker(unlock=evs[1], wait_on=evs[2])) + + # Terminate the first transaction before the second does + await asyncio.gather(t1) + evs[2].set() + await asyncio.gather(t2) diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py new file mode 100644 index 0000000..d0e57e6 --- /dev/null +++ b/tests/test_typeinfo.py @@ -0,0 +1,145 @@ +import pytest + +import psycopg +from psycopg import sql +from psycopg.pq import TransactionStatus +from psycopg.types import TypeInfo + + +@pytest.mark.parametrize("name", ["text", sql.Identifier("text")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +def test_fetch(conn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + conn.execute("select 1") + + assert conn.info.transaction_status == status + info = TypeInfo.fetch(conn, name) + assert conn.info.transaction_status == status + + assert info.name == "text" + # TODO: add the schema? + # assert info.schema == "pg_catalog" + + assert info.oid == psycopg.adapters.types["text"].oid + assert info.array_oid == psycopg.adapters.types["text"].array_oid + assert info.regtype == "text" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["text", sql.Identifier("text")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +async def test_fetch_async(aconn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + await aconn.execute("select 1") + + assert aconn.info.transaction_status == status + info = await TypeInfo.fetch(aconn, name) + assert aconn.info.transaction_status == status + + assert info.name == "text" + # assert info.schema == "pg_catalog" + assert info.oid == psycopg.adapters.types["text"].oid + assert info.array_oid == psycopg.adapters.types["text"].array_oid + + +@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +def test_fetch_not_found(conn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + conn.execute("select 1") + + assert conn.info.transaction_status == status + info = TypeInfo.fetch(conn, name) + assert conn.info.transaction_status == status + assert info is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")]) +@pytest.mark.parametrize("status", ["IDLE", "INTRANS"]) +async def test_fetch_not_found_async(aconn, name, status): + status = getattr(TransactionStatus, status) + if status == TransactionStatus.INTRANS: + await aconn.execute("select 1") + + assert aconn.info.transaction_status == status + info = await TypeInfo.fetch(aconn, name) + assert aconn.info.transaction_status == status + + assert info is None + + +@pytest.mark.crdb_skip("composite") +@pytest.mark.parametrize( + "name", ["testschema.testtype", sql.Identifier("testschema", "testtype")] +) +def test_fetch_by_schema_qualified_string(conn, name): + conn.execute("create schema if not exists testschema") + conn.execute("create type testschema.testtype as (foo text)") + + info = TypeInfo.fetch(conn, name) + assert info.name == "testtype" + # assert info.schema == "testschema" + cur = conn.execute( + """ + select oid, typarray from pg_type + where oid = 'testschema.testtype'::regtype + """ + ) + assert cur.fetchone() == (info.oid, info.array_oid) + + +@pytest.mark.parametrize( + "name", + [ + "text", + # TODO: support these? + # "pg_catalog.text", + # sql.Identifier("text"), + # sql.Identifier("pg_catalog", "text"), + ], +) +def test_registry_by_builtin_name(conn, name): + info = psycopg.adapters.types[name] + assert info.name == "text" + assert info.oid == 25 + + +def test_registry_empty(): + r = psycopg.types.TypesRegistry() + assert r.get("text") is None + with pytest.raises(KeyError): + r["text"] + + +@pytest.mark.parametrize("oid, aoid", [(1, 2), (1, 0), (0, 2), (0, 0)]) +def test_registry_invalid_oid(oid, aoid): + r = psycopg.types.TypesRegistry() + ti = psycopg.types.TypeInfo("test", oid, aoid) + r.add(ti) + assert r["test"] is ti + if oid: + assert r[oid] is ti + if aoid: + assert r[aoid] is ti + with pytest.raises(KeyError): + r[0] + + +def test_registry_copy(): + r = psycopg.types.TypesRegistry(psycopg.postgres.types) + assert r.get("text") is r["text"] is r[25] + assert r["text"].oid == 25 + + +def test_registry_isolated(): + orig = psycopg.postgres.types + tinfo = orig["text"] + r = psycopg.types.TypesRegistry(orig) + tdummy = psycopg.types.TypeInfo("dummy", tinfo.oid, tinfo.array_oid) + r.add(tdummy) + assert r[25] is r["dummy"] is tdummy + assert orig[25] is r["text"] is tinfo diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..fff9cec --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,449 @@ +import os + +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +@pytest.mark.parametrize( + "filename", + ["adapters_example.py", "typing_example.py"], +) +def test_typing_example(mypy, filename): + cp = mypy.run_on_file(os.path.join(HERE, filename)) + errors = cp.stdout.decode("utf8", "replace").splitlines() + assert not errors + assert cp.returncode == 0 + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.tuple_row)", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.namedtuple_row)", + "psycopg.Connection[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.Connection.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.Connection.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "conn.cursor()", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor()", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg.Cursor[Thing]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor()", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn, row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "psycopg.ServerCursor(conn, 'foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)", + "psycopg.ServerCursor[NamedTuple]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type_init(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Optional[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Optional[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Optional[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchone_type(conn_class, server_side, curs, type, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.fetchone() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Tuple[Any, ...]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Dict[str, Any]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Thing", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_iter_type(conn_class, server_side, curs, type, mypy): + if "Async" in conn_class: + async_ = "async " + await_ = "await " + else: + async_ = await_ = "" + + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_}psycopg.{conn_class}.connect() +curs = {curs} +{async_}for obj in curs: + pass +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("method", ["fetchmany", "fetchall"]) +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "List[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "List[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "List[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.{method}() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_cur_subclass_execute(mypy, conn_class, server_side): + async_ = "async " if "Async" in conn_class else "" + await_ = "await" if "Async" in conn_class else "" + cur_base_class = "".join( + [ + "Async" if "Async" in conn_class else "", + "Server" if server_side else "", + "Cursor", + ] + ) + cur_name = "'foo'" if server_side else "" + + src = f"""\ +from typing import Any, cast +import psycopg +from psycopg.rows import Row, TupleRow + +class MyCursor(psycopg.{cur_base_class}[Row]): + pass + +{async_}def test() -> None: + conn = {await_} psycopg.{conn_class}.connect() + + cur: MyCursor[TupleRow] + reveal_type(cur) + + cur = cast(MyCursor[TupleRow], conn.cursor({cur_name})) + {async_}with cur as cur2: + reveal_type(cur2) + cur3 = {await_} cur2.execute("") + reveal_type(cur3) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 3 + types = [mypy.get_revealed(line) for line in out] + assert types[0] == types[1] + assert types[0] == types[2] + + +def _test_reveal(stmts, type, mypy): + ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]" + stmts = "\n".join(f" {line}" for line in stmts.splitlines()) + + src = f"""\ +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence +from typing import Tuple, Union +import psycopg +from psycopg import rows + +class Thing: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + +def thing_row( + cur: Union[psycopg.Cursor[Any], psycopg.AsyncCursor[Any]], +) -> Callable[[Sequence[Any]], Thing]: + assert cur.description + names = [d.name for d in cur.description] + + def make_row(t: Sequence[Any]) -> Thing: + return Thing(**dict(zip(names, t))) + + return make_row + +async def tmp() -> None: +{stmts} + reveal_type(obj) + +ref: {type} = None {ignore} +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want + + +@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308") +@pytest.mark.parametrize( + "conn, type", + [ + ( + "MyConnection.connect()", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.tuple_row)", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.dict_row)", + "MyConnection[Dict[str, Any]]", + ), + ], +) +def test_generic_connect(conn, type, mypy): + src = f""" +from typing import Any, Dict, Tuple +import psycopg +from psycopg import rows + +class MyConnection(psycopg.Connection[rows.Row]): + pass + +obj = {conn} +reveal_type(obj) + +ref: {type} = None # type: ignore[assignment] +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want diff --git a/tests/test_waiting.py b/tests/test_waiting.py new file mode 100644 index 0000000..63237e8 --- /dev/null +++ b/tests/test_waiting.py @@ -0,0 +1,159 @@ +import select # noqa: used in pytest.mark.skipif +import socket +import sys + +import pytest + +import psycopg +from psycopg import waiting +from psycopg import generators +from psycopg.pq import ConnStatus, ExecStatus + +skip_if_not_linux = pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="non-Linux platform" +) + +waitfns = [ + "wait", + "wait_selector", + pytest.param( + "wait_select", marks=pytest.mark.skipif("not hasattr(select, 'select')") + ), + pytest.param( + "wait_epoll", marks=pytest.mark.skipif("not hasattr(select, 'epoll')") + ), + pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")), +] + +timeouts = [pytest.param({}, id="blank")] +timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]] + + +@pytest.mark.parametrize("timeout", timeouts) +def test_wait_conn(dsn, timeout): + gen = generators.connect(dsn) + conn = waiting.wait_conn(gen, **timeout) + assert conn.status == ConnStatus.OK + + +def test_wait_conn_bad(dsn): + gen = generators.connect("dbname=nosuchdb") + with pytest.raises(psycopg.OperationalError): + waiting.wait_conn(gen) + + +@pytest.mark.parametrize("waitfn", waitfns) +@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready)) +@skip_if_not_linux +def test_wait_ready(waitfn, wait, ready): + waitfn = getattr(waiting, waitfn) + + def gen(): + r = yield wait + return r + + with socket.socket() as s: + r = waitfn(gen(), s.fileno()) + assert r & ready + + +@pytest.mark.parametrize("waitfn", waitfns) +@pytest.mark.parametrize("timeout", timeouts) +def test_wait(pgconn, waitfn, timeout): + waitfn = getattr(waiting, waitfn) + + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + (res,) = waitfn(gen, pgconn.socket, **timeout) + assert res.status == ExecStatus.TUPLES_OK + + +@pytest.mark.parametrize("waitfn", waitfns) +def test_wait_bad(pgconn, waitfn): + waitfn = getattr(waiting, waitfn) + + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + waitfn(gen, pgconn.socket) + + +@pytest.mark.slow +@pytest.mark.skipif( + "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious" +) +@pytest.mark.parametrize("waitfn", waitfns) +def test_wait_large_fd(dsn, waitfn): + waitfn = getattr(waiting, waitfn) + + files = [] + try: + try: + for i in range(1100): + files.append(open(__file__)) + except OSError: + pytest.skip("can't open the number of files needed for the test") + + pgconn = psycopg.pq.PGconn.connect(dsn.encode()) + try: + assert pgconn.socket > 1024 + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + if waitfn is waiting.wait_select: + with pytest.raises(ValueError): + waitfn(gen, pgconn.socket) + else: + (res,) = waitfn(gen, pgconn.socket) + assert res.status == ExecStatus.TUPLES_OK + finally: + pgconn.finish() + finally: + for f in files: + f.close() + + +@pytest.mark.parametrize("timeout", timeouts) +@pytest.mark.asyncio +async def test_wait_conn_async(dsn, timeout): + gen = generators.connect(dsn) + conn = await waiting.wait_conn_async(gen, **timeout) + assert conn.status == ConnStatus.OK + + +@pytest.mark.asyncio +async def test_wait_conn_async_bad(dsn): + gen = generators.connect("dbname=nosuchdb") + with pytest.raises(psycopg.OperationalError): + await waiting.wait_conn_async(gen) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready)) +@skip_if_not_linux +async def test_wait_ready_async(wait, ready): + def gen(): + r = yield wait + return r + + with socket.socket() as s: + r = await waiting.wait_async(gen(), s.fileno()) + assert r & ready + + +@pytest.mark.asyncio +async def test_wait_async(pgconn): + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + (res,) = await waiting.wait_async(gen, pgconn.socket) + assert res.status == ExecStatus.TUPLES_OK + + +@pytest.mark.asyncio +async def test_wait_async_bad(pgconn): + pgconn.send_query(b"select 1") + gen = generators.execute(pgconn) + socket = pgconn.socket + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + await waiting.wait_async(gen, socket) diff --git a/tests/test_windows.py b/tests/test_windows.py new file mode 100644 index 0000000..09e61ba --- /dev/null +++ b/tests/test_windows.py @@ -0,0 +1,23 @@ +import pytest +import asyncio +import sys + +from psycopg.errors import InterfaceError + + +@pytest.mark.skipif(sys.platform != "win32", reason="windows only test") +def test_windows_error(aconn_cls, dsn): + loop = asyncio.ProactorEventLoop() # type: ignore[attr-defined] + + async def go(): + with pytest.raises( + InterfaceError, + match="Psycopg cannot use the 'ProactorEventLoop'", + ): + await aconn_cls.connect(dsn) + + try: + loop.run_until_complete(go()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/types/__init__.py diff --git a/tests/types/test_array.py b/tests/types/test_array.py new file mode 100644 index 0000000..74c17a6 --- /dev/null +++ b/tests/types/test_array.py @@ -0,0 +1,338 @@ +from typing import List, Any +from decimal import Decimal + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat, Transformer, Dumper +from psycopg.types import TypeInfo +from psycopg._compat import prod +from psycopg.postgres import types as builtins + + +tests_str = [ + ([[[[[["a"]]]]]], "{{{{{{a}}}}}}"), + ([[[[[[None]]]]]], "{{{{{{NULL}}}}}}"), + ([[[[[["NULL"]]]]]], '{{{{{{"NULL"}}}}}}'), + (["foo", "bar", "baz"], "{foo,bar,baz}"), + (["foo", None, "baz"], "{foo,null,baz}"), + (["foo", "null", "", "baz"], '{foo,"null","",baz}'), + ( + [["foo", "bar"], ["baz", "qux"], ["quux", "quuux"]], + "{{foo,bar},{baz,qux},{quux,quuux}}", + ), + ( + [[["fo{o", "ba}r"], ['ba"z', "qu'x"], ["qu ux", " "]]], + r'{{{"fo{o","ba}r"},{"ba\"z",qu\'x},{"qu ux"," "}}}', + ), +] + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("type", ["text", "int4"]) +def test_dump_empty_list(conn, fmt_in, type): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}::{type}[] = %s::{type}[]", ([], "{}")) + assert cur.fetchone()[0] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("obj, want", tests_str) +def test_dump_list_str(conn, obj, want, fmt_in): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}::text[] = %s::text[]", (obj, want)) + assert cur.fetchone()[0] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_empty_list_str(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::text[]", ([],)) + assert cur.fetchone()[0] == [] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("want, obj", tests_str) +def test_load_list_str(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::text[]", (obj,)) + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_all_chars(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + c = chr(i) + cur.execute(f"select %{fmt_in.value}::text[]", ([c],)) + assert cur.fetchone()[0] == [c] + + a = list(map(chr, range(1, 256))) + a.append("\u20ac") + cur.execute(f"select %{fmt_in.value}::text[]", (a,)) + assert cur.fetchone()[0] == a + + s = "".join(a) + cur.execute(f"select %{fmt_in.value}::text[]", ([s],)) + assert cur.fetchone()[0] == [s] + + +tests_int = [ + ([10, 20, -30], "{10,20,-30}"), + ([10, None, 30], "{10,null,30}"), + ([[10, 20], [30, 40]], "{{10,20},{30,40}}"), +] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("obj, want", tests_int) +def test_dump_list_int(conn, obj, want): + cur = conn.cursor() + cur.execute("select %s::int[] = %s::int[]", (obj, want)) + assert cur.fetchone()[0] + + +@pytest.mark.parametrize( + "input", + [ + [["a"], ["b", "c"]], + [["a"], []], + [[["a"]], ["b"]], + # [["a"], [["b"]]], # todo, but expensive (an isinstance per item) + # [True, b"a"], # TODO expensive too + ], +) +def test_bad_binary_array(input): + tx = Transformer() + with pytest.raises(psycopg.DataError): + tx.get_dumper(input, PyFormat.BINARY).dump(input) + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("want, obj", tests_int) +def test_load_list_int(conn, obj, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::int[]", (obj,)) + assert cur.fetchone()[0] == want + + stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format( + obj, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["int4[]"]) + (got,) = copy.read_row() + + assert got == want + + +@pytest.mark.crdb_skip("composite") +def test_array_register(conn): + conn.execute("create table mytype (data text)") + cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""") + res = cur.fetchone() + assert res[0] == "(foo)" + assert res[1] == "{(foo)}" + + info = TypeInfo.fetch(conn, "mytype") + info.register(conn) + + cur = conn.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[]""") + res = cur.fetchone() + assert res[0] == "(foo)" + assert res[1] == ["(foo)"] + + +@pytest.mark.crdb("skip", reason="aclitem") +def test_array_of_unknown_builtin(conn): + user = conn.execute("select user").fetchone()[0] + # we cannot load this type, but we understand it is an array + val = f"{user}=arwdDxt/{user}" + cur = conn.execute(f"select '{val}'::aclitem, array['{val}']::aclitem[]") + res = cur.fetchone() + assert cur.description[0].type_code == builtins["aclitem"].oid + assert res[0] == val + assert cur.description[1].type_code == builtins["aclitem"].array_oid + assert res[1] == [val] + + +@pytest.mark.parametrize( + "num, type", + [ + (0, "int2"), + (2**15 - 1, "int2"), + (-(2**15), "int2"), + (2**15, "int4"), + (2**31 - 1, "int4"), + (-(2**31), "int4"), + (2**31, "int8"), + (2**63 - 1, "int8"), + (-(2**63), "int8"), + (2**63, "numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_numbers_array(num, type, fmt_in): + for array in ([num], [1, num]): + tx = Transformer() + dumper = tx.get_dumper(array, fmt_in) + dumper.dump(array) + assert dumper.oid == builtins[type].array_oid + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out): + wrapper = getattr(psycopg.types.numeric, wrapper) + if wrapper is Decimal: + want_cls = Decimal + else: + assert wrapper.__mro__[1] in (int, float) + want_cls = wrapper.__mro__[1] + + obj = [wrapper(1), wrapper(0), wrapper(-1), None] + cur = conn.cursor(binary=fmt_out) + got = cur.execute(f"select %{fmt_in.value}", [obj]).fetchone()[0] + assert got == obj + for i in got: + if i is not None: + assert type(i) is want_cls + + +def test_mix_types(conn): + with pytest.raises(psycopg.DataError): + conn.execute("select %s", ([1, 0.5],)) + + with pytest.raises(psycopg.DataError): + conn.execute("select %s", ([1, Decimal("0.5")],)) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list_mix(conn, fmt_in): + objs = list(range(3)) + conn.execute("create table testarrays (col1 bigint[], col2 bigint[])") + # pro tip: don't get confused with the types + f1, f2 = conn.execute( + f"insert into testarrays values (%{fmt_in.value}, %{fmt_in.value}) returning *", + (objs, []), + ).fetchone() + assert f1 == objs + assert f2 == [] + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table test (id serial primary key, data date[])") + with conn.transaction(): + cur.execute( + f"insert into test (data) values (%{fmt_in.value}) returning id", ([],) + ) + id = cur.fetchone()[0] + cur.execute("select data from test") + assert cur.fetchone() == ([],) + + # test untyped list in a filter + cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([id],)) + assert cur.fetchone() + cur.execute(f"select data from test where id = any(%{fmt_in.value})", ([],)) + assert not cur.fetchone() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_empty_list_after_choice(conn, fmt_in): + cur = conn.cursor() + cur.execute("create table test (id serial primary key, data float[])") + cur.executemany( + f"insert into test (data) values (%{fmt_in.value})", [([1.0],), ([],)] + ) + cur.execute("select data from test order by id") + assert cur.fetchall() == [([1.0],), ([],)] + + +@pytest.mark.crdb_skip("geometric types") +def test_dump_list_no_comma_separator(conn): + class Box: + def __init__(self, x1, y1, x2, y2): + self.coords = (x1, y1, x2, y2) + + class BoxDumper(Dumper): + + format = pq.Format.TEXT + oid = psycopg.postgres.types["box"].oid + + def dump(self, box): + return ("(%s,%s),(%s,%s)" % box.coords).encode() + + conn.adapters.register_dumper(Box, BoxDumper) + + cur = conn.execute("select (%s::box)::text", (Box(1, 2, 3, 4),)) + got = cur.fetchone()[0] + assert got == "(3,4),(1,2)" + + cur = conn.execute( + "select (%s::box[])::text", ([Box(1, 2, 3, 4), Box(5, 4, 3, 2)],) + ) + got = cur.fetchone()[0] + assert got == "{(3,4),(1,2);(5,4),(3,2)}" + + +@pytest.mark.crdb_skip("geometric types") +def test_load_array_no_comma_separator(conn): + cur = conn.execute("select '{(2,2),(1,1);(5,6),(3,4)}'::box[]") + # Not parsed at the moment, but split ok on ; separator + assert cur.fetchone()[0] == ["(2,2),(1,1)", "(5,6),(3,4)"] + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_nested_array(conn, fmt_out): + dims = [3, 4, 5, 6] + a: List[Any] = list(range(prod(dims))) + for dim in dims[-1:0:-1]: + a = [a[i : i + dim] for i in range(0, len(a), dim)] + + assert a[2][3][4][5] == prod(dims) - 1 + + sa = str(a).replace("[", "{").replace("]", "}") + got = conn.execute("select %s::int[][][][]", [sa], binary=fmt_out).fetchone()[0] + assert got == a + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "obj, want", + [ + ("'[0:1]={a,b}'::text[]", ["a", "b"]), + ("'[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}'::int[]", [[[1, 2, 3], [4, 5, 6]]]), + ], +) +def test_array_with_bounds(conn, obj, want, fmt_out): + got = conn.execute(f"select {obj}", binary=fmt_out).fetchone()[0] + assert got == want + + +@pytest.mark.crdb_skip("nested array") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_all_chars_with_bounds(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + c = chr(i) + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([c],)) + assert cur.fetchone()[0] == ["a", "b", c] + + a = list(map(chr, range(1, 256))) + a.append("\u20ac") + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", (a,)) + assert cur.fetchone()[0] == ["a", "b"] + a + + s = "".join(a) + cur.execute("select '[0:1]={a,b}'::text[] || %s::text[]", ([s],)) + assert cur.fetchone()[0] == ["a", "b", s] diff --git a/tests/types/test_bool.py b/tests/types/test_bool.py new file mode 100644 index 0000000..edd4dad --- /dev/null +++ b/tests/types/test_bool.py @@ -0,0 +1,47 @@ +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat +from psycopg.postgres import types as builtins + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("b", [True, False]) +def test_roundtrip_bool(conn, b, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + result = cur.execute(f"select %{fmt_in.value}", (b,)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].oid + assert result is b + + result = cur.execute(f"select %{fmt_in.value}", ([b],)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].array_oid + assert result[0] is b + + +@pytest.mark.parametrize("val", [True, False]) +def test_quote_bool(conn, val): + + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode( + "ascii" + ) + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val))) + assert cur.fetchone()[0] is val + + +def test_quote_none(conn): + + tx = Transformer() + assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL" + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None))) + assert cur.fetchone()[0] is None diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py new file mode 100644 index 0000000..47beecf --- /dev/null +++ b/tests/types/test_composite.py @@ -0,0 +1,396 @@ +import pytest + +from psycopg import pq, postgres, sql +from psycopg.adapt import PyFormat +from psycopg.postgres import types as builtins +from psycopg.types.range import Range +from psycopg.types.composite import CompositeInfo, register_composite +from psycopg.types.composite import TupleDumper, TupleBinaryDumper + +from ..utils import eur +from ..fix_crdb import is_crdb, crdb_skip_message + + +pytestmark = pytest.mark.crdb_skip("composite") + +tests_str = [ + ("", ()), + # Funnily enough there's no way to represent (None,) in Postgres + ("null", ()), + ("null,null", (None, None)), + ("null, ''", (None, "")), + ( + "42,'foo','ba,r','ba''z','qu\"x'", + ("42", "foo", "ba,r", "ba'z", 'qu"x'), + ), + ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"')), +] + + +@pytest.mark.parametrize("rec, want", tests_str) +def test_load_record(conn, want, rec): + cur = conn.cursor() + res = cur.execute(f"select row({rec})").fetchone()[0] + assert res == want + + +@pytest.mark.parametrize("rec, obj", tests_str) +def test_dump_tuple(conn, rec, obj): + cur = conn.cursor() + fields = [f"f{i} text" for i in range(len(obj))] + cur.execute( + f""" + drop type if exists tmptype; + create type tmptype as ({', '.join(fields)}); + """ + ) + info = CompositeInfo.fetch(conn, "tmptype") + register_composite(info, conn) + + res = conn.execute("select %s::tmptype", [obj]).fetchone()[0] + assert res == obj + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_all_chars(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0] + assert res == (chr(i),) + + cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256))) + res = cur.fetchone()[0] + assert res == tuple(map(chr, range(1, 256))) + + s = "".join(map(chr, range(1, 256))) + res = cur.execute("select row(%s::text)", [s]).fetchone()[0] + assert res == (s,) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_range(conn, fmt_in): + conn.execute( + """ + drop type if exists tmptype; + create type tmptype as (num integer, range daterange, nums integer[]) + """ + ) + info = CompositeInfo.fetch(conn, "tmptype") + register_composite(info, conn) + + cur = conn.execute( + f"select pg_typeof(%{fmt_in.value})", + [info.python_type(10, Range(empty=True), [])], + ) + assert cur.fetchone()[0] == "tmptype" + + +@pytest.mark.parametrize( + "rec, want", + [ + ("", ()), + ("null", (None,)), # Unlike text format, this is a thing + ("null,null", (None, None)), + ("null, ''", (None, b"")), + ( + "42,'foo','ba,r','ba''z','qu\"x'", + (42, b"foo", b"ba,r", b"ba'z", b'qu"x'), + ), + ( + "'foo''', '''foo', '\"bar', 'bar\"' ", + (b"foo'", b"'foo", b'"bar', b'bar"'), + ), + ( + "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ", + (10, None, 20.0, None, "foo", b"bar"), + ), + ], +) +def test_load_record_binary(conn, want, rec): + cur = conn.cursor(binary=True) + res = cur.execute(f"select row({rec})").fetchone()[0] + assert res == want + for o1, o2 in zip(res, want): + assert type(o1) is type(o2) + + +@pytest.fixture(scope="session") +def testcomp(svcconn): + if is_crdb(svcconn): + pytest.skip(crdb_skip_message("composite")) + cur = svcconn.cursor() + cur.execute( + """ + create schema if not exists testschema; + + drop type if exists testcomp cascade; + drop type if exists testschema.testcomp cascade; + + create type testcomp as (foo text, bar int8, baz float8); + create type testschema.testcomp as (foo text, bar int8, qux bool); + """ + ) + return CompositeInfo.fetch(svcconn, "testcomp") + + +fetch_cases = [ + ( + "testcomp", + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + "testschema.testcomp", + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ( + sql.Identifier("testcomp"), + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + sql.Identifier("testschema", "testcomp"), + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), +] + + +@pytest.mark.parametrize("name, fields", fetch_cases) +def test_fetch_info(conn, testcomp, name, fields): + info = CompositeInfo.fetch(conn, name) + assert info.name == "testcomp" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 + for i, (name, t) in enumerate(fields): + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, fields", fetch_cases) +async def test_fetch_info_async(aconn, testcomp, name, fields): + info = await CompositeInfo.fetch(aconn, name) + assert info.name == "testcomp" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 + for i, (name, t) in enumerate(fields): + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid + + +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_tuple_all_chars(conn, fmt_in, testcomp): + cur = conn.cursor() + for i in range(1, 256): + (res,) = cur.execute( + f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}::testcomp", + (i, (chr(i), 1, 1.0)), + ).fetchone() + assert res is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_composite_all_chars(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + for i in range(1, 256): + obj = factory(chr(i), 1, 1.0) + (res,) = cur.execute( + f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in.value}", (i, obj) + ).fetchone() + assert res is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_composite_null(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + + obj = factory("foo", 1, None) + rec = cur.execute( + f""" + select row('foo', 1, NULL)::testcomp = %(obj){fmt_in.value}, + %(obj){fmt_in.value}::text + """, + {"obj": obj}, + ).fetchone() + assert rec[0] is True, rec[1] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_composite(conn, testcomp, fmt_out): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + + cur = conn.cursor(binary=fmt_out) + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert res.foo == "hello" + assert res.bar == 10 + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float) + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_composite_factory(conn, testcomp, fmt_out): + info = CompositeInfo.fetch(conn, "testcomp") + + class MyThing: + def __init__(self, *args): + self.foo, self.bar, self.baz = args + + register_composite(info, conn, factory=MyThing) + assert info.python_type is MyThing + + cur = conn.cursor(binary=fmt_out) + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert isinstance(res, MyThing) + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float) + + +def test_register_scope(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert postgres.adapters._loaders[fmt].pop(oid) + + for f in PyFormat: + assert postgres.adapters._dumpers[f].pop(info.python_type) + + cur = conn.cursor() + register_composite(info, cur) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert oid not in postgres.adapters._loaders[fmt] + assert oid not in conn.adapters._loaders[fmt] + assert oid in cur.adapters._loaders[fmt] + + register_composite(info, conn) + for fmt in pq.Format: + for oid in (info.oid, info.array_oid): + assert oid not in postgres.adapters._loaders[fmt] + assert oid in conn.adapters._loaders[fmt] + + +def test_type_dumper_registered(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + assert issubclass(info.python_type, tuple) + assert info.python_type.__name__ == "testcomp" + d = conn.adapters.get_dumper(info.python_type, "s") + assert issubclass(d, TupleDumper) + assert d is not TupleDumper + + tc = info.python_type("foo", 42, 3.14) + cur = conn.execute("select pg_typeof(%s)", [tc]) + assert cur.fetchone()[0] == "testcomp" + + +def test_type_dumper_registered_binary(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + assert issubclass(info.python_type, tuple) + assert info.python_type.__name__ == "testcomp" + d = conn.adapters.get_dumper(info.python_type, "b") + assert issubclass(d, TupleBinaryDumper) + assert d is not TupleBinaryDumper + + tc = info.python_type("foo", 42, 3.14) + cur = conn.execute("select pg_typeof(%b)", [tc]) + assert cur.fetchone()[0] == "testcomp" + + +def test_callable_dumper_not_registered(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + + def fac(*args): + return args + (args[-1],) + + register_composite(info, conn, factory=fac) + assert info.python_type is None + + # but the loader is registered + cur = conn.execute("select '(foo,42,3.14)'::testcomp") + assert cur.fetchone()[0] == ("foo", 42, 3.14, 3.14) + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="composite"): + register_composite(None, conn) # type: ignore[arg-type] + + +def test_invalid_fields_names(conn): + conn.execute("set client_encoding to utf8") + conn.execute( + f""" + create type "a-b" as ("c-d" text, "{eur}" int); + create type "-x-{eur}" as ("w-ww" "a-b", "0" int); + """ + ) + ab = CompositeInfo.fetch(conn, '"a-b"') + x = CompositeInfo.fetch(conn, f'"-x-{eur}"') + register_composite(ab, conn) + register_composite(x, conn) + obj = x.python_type(ab.python_type("foo", 10), 20) + conn.execute(f"""create table meh (wat "-x-{eur}")""") + conn.execute("insert into meh values (%s)", [obj]) + got = conn.execute("select wat from meh").fetchone()[0] + assert obj == got + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "1", "'"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute( + sql.SQL("create type {name} as (foo text)").format(name=sql.Identifier(name)) + ) + info = CompositeInfo.fetch(conn, sql.Identifier(name).as_string(conn)) + register_composite(info, conn) + obj = info.python_type("hello") + assert sql.Literal(obj).as_string(conn) == f"'(hello)'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + got = cur.fetchone()[0] + assert got == obj + assert type(got) is type(obj) + + +@pytest.mark.parametrize( + "name, attr", + [ + ("a-b", "a_b"), + (f"{eur}", "f_"), + ("üåäö", "üåäö"), + ("order", "order"), + ("1", "f1"), + ], +) +def test_literal_invalid_attr(conn, name, attr): + conn.execute("set client_encoding to utf8") + conn.execute( + sql.SQL("create type test_attr as ({name} text)").format( + name=sql.Identifier(name) + ) + ) + info = CompositeInfo.fetch(conn, "test_attr") + register_composite(info, conn) + obj = info.python_type("hello") + assert getattr(obj, attr) == "hello" + cur = conn.execute(sql.SQL("select {}").format(obj)) + got = cur.fetchone()[0] + assert got == obj + assert type(got) is type(obj) diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py new file mode 100644 index 0000000..11fe493 --- /dev/null +++ b/tests/types/test_datetime.py @@ -0,0 +1,813 @@ +import datetime as dt + +import pytest + +from psycopg import DataError, pq, sql +from psycopg.adapt import PyFormat + +crdb_skip_datestyle = pytest.mark.crdb("skip", reason="set datestyle/intervalstyle") +crdb_skip_negative_interval = pytest.mark.crdb("skip", reason="negative interval") +crdb_skip_invalid_tz = pytest.mark.crdb( + "skip", reason="crdb doesn't allow invalid timezones" +) + +datestyles_in = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["DMY", "MDY", "YMD"] +] +datestyles_out = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["ISO", "Postgres", "SQL", "German"] +] + +intervalstyles = [ + pytest.param(datestyle, marks=crdb_skip_datestyle) + for datestyle in ["sql_standard", "postgres", "postgres_verbose", "iso_8601"] +] + + +class TestDate: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_date(self, conn, val, expr, fmt_in): + val = as_date(val) + cur = conn.cursor() + cur.execute(f"select '{expr}'::date = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + cur.execute( + sql.SQL("select {}::date = {}").format( + sql.Literal(val), sql.Placeholder(format=fmt_in) + ), + (val,), + ) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_date_datestyle(self, conn, datestyle_in): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO,{datestyle_in}") + cur.execute("select 'epoch'::date + 1 = %t", (dt.date(1970, 1, 2),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_date(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::date") + assert cur.fetchone()[0] == as_date(val) + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_date_datestyle(self, conn, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select '2000-01-02'::date") + assert cur.fetchone()[0] == dt.date(2000, 1, 2) + + @pytest.mark.parametrize("val", ["min", "max"]) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_date_overflow(self, conn, val, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1)) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_date_overflow_binary(self, conn, val): + cur = conn.cursor(binary=True) + cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1)) + with pytest.raises(DataError): + cur.fetchone()[0] + + overflow_samples = [ + ("-infinity", "date too small"), + ("1000-01-01 BC", "date too small"), + ("10000-01-01", "date too large"), + ("infinity", "date too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_load_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::date", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_load_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::date", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + def test_infinity_date_example(self, conn): + # NOTE: this is an example in the docs. Make sure it doesn't regress when + # adding binary datetime adapters + from datetime import date + from psycopg.types.datetime import DateLoader, DateDumper + + class InfDateDumper(DateDumper): + def dump(self, obj): + if obj == date.max: + return b"infinity" + else: + return super().dump(obj) + + class InfDateLoader(DateLoader): + def load(self, data): + if data == b"infinity": + return date.max + else: + return super().load(data) + + cur = conn.cursor() + cur.adapters.register_dumper(date, InfDateDumper) + cur.adapters.register_loader("date", InfDateLoader) + + rec = cur.execute( + "SELECT %s::text, %s::text", [date(2020, 12, 31), date.max] + ).fetchone() + assert rec == ("2020-12-31", "infinity") + rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone() + assert rec == (date(2020, 12, 31), date(9999, 12, 31)) + + +class TestDatetime: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "0001-01-01 00:00"), + ("258,1,8,1,12,32,358261", "0258-1-8 1:12:32.358261"), + ("1000,1,1,0,0", "1000-01-01 00:00"), + ("2000,1,1,0,0", "2000-01-01 00:00"), + ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"), + ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"), + ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"), + ("2000,1,1,0,0,0,1", "2000-01-01 00:00:00.000001"), + ("2034,02,03,23,34,27,951357", "2034-02-03 23:34:27.951357"), + ("2200,1,1,0,0,0,1", "2200-01-01 00:00:00.000001"), + ("2300,1,1,0,0,0,1", "2300-01-01 00:00:00.000001"), + ("7000,1,1,0,0,0,1", "7000-01-01 00:00:00.000001"), + ("max", "9999-12-31 23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetime(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '+02:00'") + cur.execute(f"select %{fmt_in.value}", (as_dt(val),)) + cur.execute(f"select '{expr}'::timestamp = %{fmt_in.value}", (as_dt(val),)) + cur.execute( + f""" + select '{expr}'::timestamp = %(val){fmt_in.value}, + '{expr}', %(val){fmt_in.value}::text + """, + {"val": as_dt(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (want, got) + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_datetime_datestyle(self, conn, datestyle_in): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO, {datestyle_in}") + cur.execute( + "select 'epoch'::timestamp + '1d 3h 4m 5s'::interval = %t", + (dt.datetime(1970, 1, 2, 3, 4, 5),), + ) + assert cur.fetchone()[0] is True + + load_datetime_samples = [ + ("min", "0001-01-01"), + ("1000,1,1", "1000-01-01"), + ("2000,1,1", "2000-01-01"), + ("2000,1,2,3,4,5,6", "2000-01-02 03:04:05.000006"), + ("2000,1,2,3,4,5,678", "2000-01-02 03:04:05.000678"), + ("2000,1,2,3,0,0,456789", "2000-01-02 03:00:00.456789"), + ("2000,12,31", "2000-12-31"), + ("3000,1,1", "3000-01-01"), + ("max", "9999-12-31 23:59:59.999999"), + ] + + @pytest.mark.parametrize("val, expr", load_datetime_samples) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}") + cur.execute("set timezone to '+02:00'") + cur.execute(f"select '{expr}'::timestamp") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize("val, expr", load_datetime_samples) + def test_load_datetime_binary(self, conn, val, expr): + cur = conn.cursor(binary=True) + cur.execute("set timezone to '+02:00'") + cur.execute(f"select '{expr}'::timestamp") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize("val", ["min", "max"]) + @pytest.mark.parametrize("datestyle_out", datestyles_out) + def test_load_datetime_overflow(self, conn, val, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute( + "select %t::timestamp + %s * '1s'::interval", + (as_dt(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_datetime_overflow_binary(self, conn, val): + cur = conn.cursor(binary=True) + cur.execute( + "select %t::timestamp + %s * '1s'::interval", + (as_dt(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + overflow_samples = [ + ("-infinity", "timestamp too small"), + ("1000-01-01 12:00 BC", "timestamp too small"), + ("10000-01-01 12:00", "timestamp too large"), + ("infinity", "timestamp too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::timestamp", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::timestamp", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @crdb_skip_datestyle + def test_load_all_month_names(self, conn): + cur = conn.cursor(binary=False) + cur.execute("set datestyle = 'Postgres'") + for i in range(12): + d = dt.datetime(2000, i + 1, 15) + cur.execute("select %s", [d]) + assert cur.fetchone()[0] == d + + +class TestDateTimeTz: + @pytest.mark.parametrize( + "val, expr", + [ + ("min~-2", "0001-01-01 00:00-02:00"), + ("min~-12", "0001-01-01 00:00-12:00"), + ( + "258,1,8,1,12,32,358261~1:2:3", + "0258-1-8 1:12:32.358261+01:02:03", + ), + ("1000,1,1,0,0~2", "1000-01-01 00:00+2"), + ("2000,1,1,0,0~2", "2000-01-01 00:00+2"), + ("2000,1,1,0,0~12", "2000-01-01 00:00+12"), + ("2000,1,1,0,0~-12", "2000-01-01 00:00-12"), + ("2000,1,1,0,0~01:02:03", "2000-01-01 00:00+01:02:03"), + ("2000,1,1,0,0~-01:02:03", "2000-01-01 00:00-01:02:03"), + ("2000,12,31,23,59,59,999999~2", "2000-12-31 23:59:59.999999+2"), + ( + "2034,02,03,23,34,27,951357~-4:27", + "2034-02-03 23:34:27.951357-04:27", + ), + ("2300,1,1,0,0,0,1~1", "2300-01-01 00:00:00.000001+1"), + ("3000,1,1,0,0~2", "3000-01-01 00:00+2"), + ("7000,1,1,0,0,0,1~-1:2:3", "7000-01-01 00:00:00.000001-01:02:03"), + ("max~2", "9999-12-31 23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetimetz(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '-02:00'") + cur.execute( + f""" + select '{expr}'::timestamptz = %(val){fmt_in.value}, + '{expr}', %(val){fmt_in.value}::text + """, + {"val": as_dt(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (want, got) + + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_dump_datetimetz_datestyle(self, conn, datestyle_in): + tzinfo = dt.timezone(dt.timedelta(hours=2)) + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = ISO, {datestyle_in}") + cur.execute("set timezone to '-02:00'") + cur.execute( + "select 'epoch'::timestamptz + '1d 3h 4m 5.678s'::interval = %t", + (dt.datetime(1970, 1, 2, 5, 4, 5, 678000, tzinfo=tzinfo),), + ) + assert cur.fetchone()[0] is True + + load_datetimetz_samples = [ + ("2000,1,1~2", "2000-01-01", "-02:00"), + ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"), + ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"), + ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"), + ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"), + ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"), + ("2000,12,31~2", "2000-12-31", "-02:00"), + ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"), + ] + + @crdb_skip_datestyle + @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples) + @pytest.mark.parametrize("datestyle_out", ["ISO"]) + def test_load_datetimetz(self, conn, val, expr, timezone, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, DMY") + cur.execute(f"set timezone to '{timezone}'") + got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0] + assert got == as_dt(val) + + @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples) + def test_load_datetimetz_binary(self, conn, val, expr, timezone): + cur = conn.cursor(binary=True) + cur.execute(f"set timezone to '{timezone}'") + got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0] + assert got == as_dt(val) + + @pytest.mark.xfail # parse timezone names + @crdb_skip_datestyle + @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")]) + @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"]) + @pytest.mark.parametrize("datestyle_in", datestyles_in) + def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out): + cur = conn.cursor(binary=False) + cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}") + cur.execute("set timezone to '-02:00'") + cur.execute(f"select '{expr}'::timestamptz") + assert cur.fetchone()[0] == as_dt(val) + + @pytest.mark.parametrize( + "tzname, expr, tzoff", + [ + ("UTC", "2000-1-1", 0), + ("UTC", "2000-7-1", 0), + ("Europe/Rome", "2000-1-1", 3600), + ("Europe/Rome", "2000-7-1", 7200), + ("Europe/Rome", "1000-1-1", 2996), + pytest.param("NOSUCH0", "2000-1-1", 0, marks=crdb_skip_invalid_tz), + pytest.param("0", "2000-1-1", 0, marks=crdb_skip_invalid_tz), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_datetimetz_tz(self, conn, fmt_out, tzname, expr, tzoff): + conn.execute("select set_config('TimeZone', %s, true)", [tzname]) + cur = conn.cursor(binary=fmt_out) + ts = cur.execute("select %s::timestamptz", [expr]).fetchone()[0] + assert ts.utcoffset().total_seconds() == tzoff + + @pytest.mark.parametrize( + "val, type", + [ + ("2000,1,2,3,4,5,6", "timestamp"), + ("2000,1,2,3,4,5,6~0", "timestamptz"), + ("2000,1,2,3,4,5,6~2", "timestamptz"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_datetime_tz_or_not_tz(self, conn, val, type, fmt_in): + val = as_dt(val) + cur = conn.cursor() + cur.execute( + f""" + select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value} + """, + [val, type, val], + ) + rec = cur.fetchone() + assert rec[0] is True, type + assert rec[1] == val + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '2000-01-01 01:02:03.123456-10:20'::timestamptz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timestamptz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + + overflow_samples = [ + ("-infinity", "timestamp too small"), + ("1000-01-01 12:00+00 BC", "timestamp too small"), + ("10000-01-01 12:00+00", "timestamp too large"), + ("infinity", "timestamp too large"), + ] + + @pytest.mark.parametrize("datestyle_out", datestyles_out) + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message(self, conn, datestyle_out, val, msg): + cur = conn.cursor() + cur.execute(f"set datestyle = {datestyle_out}, YMD") + cur.execute("select %s::timestamptz", (val,)) + if datestyle_out == "ISO": + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + else: + with pytest.raises(NotImplementedError): + cur.fetchone()[0] + + @pytest.mark.parametrize("val, msg", overflow_samples) + def test_overflow_message_binary(self, conn, val, msg): + cur = conn.cursor(binary=True) + cur.execute("select %s::timestamptz", (val,)) + with pytest.raises(DataError) as excinfo: + cur.fetchone()[0] + assert msg in str(excinfo.value) + + @pytest.mark.parametrize( + "valname, tzval, tzname", + [ + ("max", "-06", "America/Chicago"), + ("min", "+09:18:59", "Asia/Tokyo"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_max_with_timezone(self, conn, fmt_out, valname, tzval, tzname): + # This happens e.g. in Django when it caches forever. + # e.g. see Django test cache.tests.DBCacheTests.test_forever_timeout + val = getattr(dt.datetime, valname).replace(microsecond=0) + tz = dt.timezone(as_tzoffset(tzval)) + want = val.replace(tzinfo=tz) + + conn.execute("set timezone to '%s'" % tzname) + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::timestamptz", [str(val) + tzval]) + got = cur.fetchone()[0] + + assert got == want + + extra = "1 day" if valname == "max" else "-1 day" + with pytest.raises(DataError): + cur.execute( + "select %s::timestamptz + %s::interval", + [str(val) + tzval, extra], + ) + got = cur.fetchone()[0] + + +class TestTime: + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "00:00"), + ("10,20,30,40", "10:20:30.000040"), + ("max", "23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_time(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute( + f""" + select '{expr}'::time = %(val){fmt_in.value}, + '{expr}'::time::text, %(val){fmt_in.value}::text + """, + {"val": as_time(val)}, + ) + ok, want, got = cur.fetchone() + assert ok, (got, want) + + @pytest.mark.parametrize( + "val, expr", + [ + ("min", "00:00"), + ("1,2", "01:02"), + ("10,20", "10:20"), + ("10,20,30", "10:20:30"), + ("10,20,30,40", "10:20:30.000040"), + ("max", "23:59:59.999999"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_time(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::time") + assert cur.fetchone()[0] == as_time(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_time_24(self, conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select '24:00'::time") + with pytest.raises(DataError): + cur.fetchone()[0] + + +class TestTimeTz: + @pytest.mark.parametrize( + "val, expr", + [ + ("min~-10", "00:00-10:00"), + ("min~+12", "00:00+12:00"), + ("10,20,30,40~-2", "10:20:30.000040-02:00"), + ("10,20,30,40~0", "10:20:30.000040Z"), + ("10,20,30,40~+2:30", "10:20:30.000040+02:30"), + ("max~-12", "23:59:59.999999-12:00"), + ("max~+12", "23:59:59.999999+12:00"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_timetz(self, conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute("set timezone to '-02:00'") + cur.execute(f"select '{expr}'::timetz = %{fmt_in.value}", (as_time(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr, timezone", + [ + ("0,0~-12", "00:00", "12:00"), + ("0,0~12", "00:00", "-12:00"), + ("3,4,5,6~2", "03:04:05.000006", "-02:00"), + ("3,4,5,6~7:8", "03:04:05.000006", "-07:08"), + ("3,0,0,456789~2", "03:00:00.456789", "-02:00"), + ("3,0,0,456789~-2", "03:00:00.456789", "+02:00"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_timetz(self, conn, val, timezone, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"set timezone to '{timezone}'") + cur.execute(f"select '{expr}'::timetz") + assert cur.fetchone()[0] == as_time(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_timetz_24(self, conn, fmt_out): + cur = conn.cursor() + cur.execute("select '24:00'::timetz") + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.parametrize( + "val, type", + [ + ("3,4,5,6", "time"), + ("3,4,5,6~0", "timetz"), + ("3,4,5,6~2", "timetz"), + ], + ) + @pytest.mark.parametrize("fmt_in", PyFormat) + def test_dump_time_tz_or_not_tz(self, conn, val, type, fmt_in): + val = as_time(val) + cur = conn.cursor() + cur.execute( + f""" + select pg_typeof(%{fmt_in.value})::regtype = %s::regtype, %{fmt_in.value} + """, + [val, type, val], + ) + rec = cur.fetchone() + assert rec[0] is True, type + assert rec[1] == val + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '01:02:03.123456-10:20'::timetz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timetz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.time(1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + + +class TestInterval: + dump_timedelta_samples = [ + ("min", "-999999999 days"), + ("1d", "1 day"), + pytest.param("-1d", "-1 day", marks=crdb_skip_negative_interval), + ("1s", "1 s"), + pytest.param("-1s", "-1 s", marks=crdb_skip_negative_interval), + pytest.param("-1m", "-0.000001 s", marks=crdb_skip_negative_interval), + ("1m", "0.000001 s"), + ("max", "999999999 days 23:59:59.999999"), + ] + + @pytest.mark.parametrize("val, expr", dump_timedelta_samples) + @pytest.mark.parametrize("intervalstyle", intervalstyles) + def test_dump_interval(self, conn, val, expr, intervalstyle): + cur = conn.cursor() + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval = %t", (as_td(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize("val, expr", dump_timedelta_samples) + def test_dump_interval_binary(self, conn, val, expr): + cur = conn.cursor() + cur.execute(f"select '{expr}'::interval = %b", (as_td(val),)) + assert cur.fetchone()[0] is True + + @pytest.mark.parametrize( + "val, expr", + [ + ("1s", "1 sec"), + ("-1s", "-1 sec"), + ("60s", "1 min"), + ("3600s", "1 hour"), + ("1s,1000m", "1.001 sec"), + ("1s,1m", "1.000001 sec"), + ("1d", "1 day"), + ("-10d", "-10 day"), + ("1d,1s,1m", "1 day 1.000001 sec"), + ("-86399s,-999999m", "-23:59:59.999999"), + ("-3723s,-400000m", "-1:2:3.4"), + ("3723s,400000m", "1:2:3.4"), + ("86399s,999999m", "23:59:59.999999"), + ("30d", "30 day"), + ("365d", "1 year"), + ("-365d", "-1 year"), + ("-730d", "-2 years"), + ("1460d", "4 year"), + ("30d", "1 month"), + ("-30d", "-1 month"), + ("60d", "2 month"), + ("-90d", "-3 month"), + ], + ) + @pytest.mark.parametrize("fmt_out", pq.Format) + def test_load_interval(self, conn, val, expr, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + @crdb_skip_datestyle + @pytest.mark.xfail # weird interval outputs + @pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")]) + @pytest.mark.parametrize( + "intervalstyle", + ["sql_standard", "postgres_verbose", "iso_8601"], + ) + def test_load_interval_intervalstyle(self, conn, val, expr, intervalstyle): + cur = conn.cursor(binary=False) + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + @pytest.mark.parametrize("fmt_out", pq.Format) + @pytest.mark.parametrize("val", ["min", "max"]) + def test_load_interval_overflow(self, conn, val, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute( + "select %s + %s * '1s'::interval", + (as_td(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + @pytest.mark.crdb_skip("copy") + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '1 days +00:00:01.000001'::interval, + 'foo bar'::text + ) to stdout + """ + ) as copy: + copy.set_types(["interval", "text"]) + rec = copy.read_row() + + want = dt.timedelta(days=1, seconds=1, microseconds=1) + assert rec[0] == want + assert rec[1] == "foo bar" + + +# +# Support +# + + +def as_date(s): + return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s) + + +def as_time(s): + if "~" in s: + s, off = s.split("~") + else: + off = None + + if "," in s: + rv = dt.time(*map(int, s.split(","))) # type: ignore[arg-type] + else: + rv = getattr(dt.time, s) + if off: + rv = rv.replace(tzinfo=as_tzinfo(off)) + + return rv + + +def as_dt(s): + if "~" not in s: + return as_naive_dt(s) + + s, off = s.split("~") + rv = as_naive_dt(s) + off = as_tzoffset(off) + rv = (rv - off).replace(tzinfo=dt.timezone.utc) + return rv + + +def as_naive_dt(s): + if "," in s: + rv = dt.datetime(*map(int, s.split(","))) # type: ignore[arg-type] + else: + rv = getattr(dt.datetime, s) + + return rv + + +def as_tzoffset(s): + if s.startswith("-"): + mul = -1 + s = s[1:] + else: + mul = 1 + + fields = ("hours", "minutes", "seconds") + return mul * dt.timedelta(**dict(zip(fields, map(int, s.split(":"))))) + + +def as_tzinfo(s): + off = as_tzoffset(s) + return dt.timezone(off) + + +def as_td(s): + if s in ("min", "max"): + return getattr(dt.timedelta, s) + + suffixes = {"d": "days", "s": "seconds", "m": "microseconds"} + kwargs = {} + for part in s.split(","): + kwargs[suffixes[part[-1]]] = int(part[:-1]) + + return dt.timedelta(**kwargs) diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py new file mode 100644 index 0000000..8dfb6d4 --- /dev/null +++ b/tests/types/test_enum.py @@ -0,0 +1,363 @@ +from enum import Enum, auto + +import pytest + +from psycopg import pq, sql, errors as e +from psycopg.adapt import PyFormat +from psycopg.types import TypeInfo +from psycopg.types.enum import EnumInfo, register_enum + +from ..fix_crdb import crdb_encoding + + +class PureTestEnum(Enum): + FOO = auto() + BAR = auto() + BAZ = auto() + + +class StrTestEnum(str, Enum): + ONE = "ONE" + TWO = "TWO" + THREE = "THREE" + + +NonAsciiEnum = Enum( + "NonAsciiEnum", + {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"}, + type=str, +) + + +class IntTestEnum(int, Enum): + ONE = 1 + TWO = 2 + THREE = 3 + + +enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum] +encodings = ["utf8", crdb_encoding("latin1")] + + +@pytest.fixture(scope="session", autouse=True) +def make_test_enums(request, svcconn): + for enum in enum_cases + [NonAsciiEnum]: + ensure_enum(enum, svcconn) + + +def ensure_enum(enum, conn): + name = enum.__name__.lower() + labels = list(enum.__members__) + conn.execute( + sql.SQL( + """ + drop type if exists {name}; + create type {name} as enum ({labels}); + """ + ).format(name=sql.Identifier(name), labels=sql.SQL(",").join(labels)) + ) + return name, enum, labels + + +def test_fetch_info(conn): + info = EnumInfo.fetch(conn, "StrTestEnum") + assert info.name == "strtestenum" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.labels) == len(StrTestEnum) + assert info.labels == list(StrTestEnum.__members__) + + +@pytest.mark.asyncio +async def test_fetch_info_async(aconn): + info = await EnumInfo.fetch(aconn, "PureTestEnum") + assert info.name == "puretestenum" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.labels) == len(PureTestEnum) + assert info.labels == list(PureTestEnum.__members__) + + +def test_register_makes_a_type(conn): + info = EnumInfo.fetch(conn, "IntTestEnum") + assert info + assert info.enum is None + register_enum(info, context=conn) + assert info.enum is not None + assert [e.name for e in info.enum] == list(IntTestEnum.__members__) + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum=enum) + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out): + enum = NonAsciiEnum + conn.execute(f"set client_encoding to {encoding}") + + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum=enum) + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + conn.execute("set client_encoding to sql_ascii") + + for label in info.labels: + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}", [label], binary=fmt_out + ) + assert cur.fetchone()[0] == enum[label] + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out): + enum = NonAsciiEnum + conn.execute(f"set client_encoding to {encoding}") + + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + conn.execute("set client_encoding to sql_ascii") + + for item in enum: + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_dumper(conn, enum, fmt_in, fmt_out): + for item in enum: + if enum is PureTestEnum: + want = item.name + else: + want = item.value + + cur = conn.execute(f"select %{fmt_in.value}", [item], binary=fmt_out) + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") + for item in NonAsciiEnum: + cur = conn.execute(f"select %{fmt_in.value}", [item.value], binary=fmt_out) + assert cur.fetchone()[0] == item.value + + +@pytest.mark.parametrize("enum", enum_cases) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_loader(conn, enum, fmt_in, fmt_out): + for label in enum.__members__: + cur = conn.execute( + f"select %{fmt_in.value}::{enum.__name__}", [label], binary=fmt_out + ) + want = enum[label].name + if fmt_out == pq.Format.BINARY: + want = want.encode() + assert cur.fetchone()[0] == want + + +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") + + for label in NonAsciiEnum.__members__: + cur = conn.execute( + f"select %{fmt_in.value}::nonasciienum", [label], binary=fmt_out + ) + if fmt_out == pq.Format.TEXT: + assert cur.fetchone()[0] == label + else: + assert cur.fetchone()[0] == label.encode(encoding) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_loader(conn, fmt_in, fmt_out): + enum = PureTestEnum + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + labels = list(enum.__members__) + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out + ) + assert cur.fetchone()[0] == list(enum) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_dumper(conn, fmt_in, fmt_out): + enum = StrTestEnum + info = EnumInfo.fetch(conn, enum.__name__) + register_enum(info, conn, enum) + + cur = conn.execute(f"select %{fmt_in.value}::text[]", [list(enum)], binary=fmt_out) + assert cur.fetchone()[0] == list(enum.__members__) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_array_loader(conn, fmt_in, fmt_out): + enum = IntTestEnum + info = TypeInfo.fetch(conn, enum.__name__) + info.register(conn) + labels = list(enum.__members__) + cur = conn.execute( + f"select %{fmt_in.value}::{info.name}[]", [labels], binary=fmt_out + ) + if fmt_out == pq.Format.TEXT: + assert cur.fetchone()[0] == labels + else: + assert cur.fetchone()[0] == [item.encode() for item in labels] + + +def test_enum_error(conn): + conn.autocommit = True + + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum) + + with pytest.raises(e.DataError): + conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone() + with pytest.raises(e.DataError): + conn.execute("select 'BAR'::puretestenum").fetchone() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "mapping", + [ + {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"}, + [ + (StrTestEnum.ONE, "FOO"), + (StrTestEnum.TWO, "BAR"), + (StrTestEnum.THREE, "BAZ"), + ], + ], +) +def test_remap(conn, fmt_in, fmt_out, mapping): + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum, mapping=mapping) + + for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]: + cur = conn.execute(f"select %{fmt_in.value}::text", [StrTestEnum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out) + assert cur.fetchone()[0] is StrTestEnum[member] + + +def test_remap_rename(conn): + enum = Enum("RenamedEnum", "FOO BAR QUX") + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"}) + + for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_python(conn): + enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]} + register_enum(info, conn, enum, mapping=mapping) + + for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + + for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_postgres(conn): + enum = Enum("SmallerEnum", "FOO") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")] + register_enum(info, conn, enum, mapping=mapping) + + cur = conn.execute("select %s::text", [enum.FOO]) + assert cur.fetchone()[0] == "BAZ" + + for label in PureTestEnum.__members__: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum.FOO + + +def test_remap_by_value(conn): + enum = Enum( # type: ignore + "ByValue", + {m.lower(): m for m in PureTestEnum.__members__}, + ) + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={m: m.value for m in enum}) + + for label in PureTestEnum.__members__: + cur = conn.execute("select %s::text", [enum[label.lower()]]) + assert cur.fetchone()[0] == label + + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[label.lower()] diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py new file mode 100644 index 0000000..5142d58 --- /dev/null +++ b/tests/types/test_hstore.py @@ -0,0 +1,107 @@ +import pytest + +import psycopg +from psycopg.types import TypeInfo +from psycopg.types.hstore import HstoreLoader, register_hstore + +pytestmark = pytest.mark.crdb_skip("hstore") + + +@pytest.mark.parametrize( + "s, d", + [ + ("", {}), + ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}), + ('"a" => "1" , "b" => "2"', {"a": "1", "b": "2"}), + ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}), + (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}), + ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}), + ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}), + (r'"a\\"=>"1"', {"a\\": "1"}), + (r'"a\""=>"1"', {'a"': "1"}), + (r'"a\\\""=>"1"', {r"a\"": "1"}), + (r'"a\\\\\""=>"1"', {r'a\\"': "1"}), + ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}), + ], +) +def test_parse_ok(s, d): + loader = HstoreLoader(0, None) + assert loader.load(s.encode()) == d + + +@pytest.mark.parametrize( + "s", + [ + "a", + '"a"', + r'"a\\""=>"1"', + r'"a\\\\""=>"1"', + '"a=>"1"', + '"a"=>"1", "b"=>NUL', + ], +) +def test_parse_bad(s): + with pytest.raises(psycopg.DataError): + loader = HstoreLoader(0, None) + loader.load(s.encode()) + + +def test_register_conn(hstore, conn): + info = TypeInfo.fetch(conn, "hstore") + register_hstore(info, conn) + assert conn.adapters.types[info.oid].name == "hstore" + + cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + + +def test_register_curs(hstore, conn): + info = TypeInfo.fetch(conn, "hstore") + cur = conn.cursor() + register_hstore(info, cur) + assert conn.adapters.types.get(info.oid) is None + assert cur.adapters.types[info.oid].name == "hstore" + + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + + +def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters): + info = TypeInfo.fetch(svcconn, "hstore") + register_hstore(info) + assert psycopg.adapters.types[info.oid].name == "hstore" + + assert svcconn.adapters.types.get(info.oid) is None + conn = conn_cls.connect(dsn) + assert conn.adapters.types[info.oid].name == "hstore" + + cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + assert cur.fetchone() == (None, {}, {"a": "b"}) + conn.close() + + +ab = list(map(chr, range(32, 128))) +samp = [ + {}, + {"a": "b", "c": None}, + dict(zip(ab, ab)), + {"".join(ab): "".join(ab)}, +] + + +@pytest.mark.parametrize("d", samp) +def test_roundtrip(hstore, conn, d): + register_hstore(TypeInfo.fetch(conn, "hstore"), conn) + d1 = conn.execute("select %s", [d]).fetchone()[0] + assert d == d1 + + +def test_roundtrip_array(hstore, conn): + register_hstore(TypeInfo.fetch(conn, "hstore"), conn) + samp1 = conn.execute("select %s", (samp,)).fetchone()[0] + assert samp1 == samp + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="hstore.*extension"): + register_hstore(None, conn) # type: ignore[arg-type] diff --git a/tests/types/test_json.py b/tests/types/test_json.py new file mode 100644 index 0000000..50e8ce3 --- /dev/null +++ b/tests/types/test_json.py @@ -0,0 +1,182 @@ +import json +from copy import deepcopy + +import pytest + +import psycopg.types +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat +from psycopg.types.json import set_json_dumps, set_json_loads + +samples = [ + "null", + "true", + '"te\'xt"', + '"\\u00e0\\u20ac"', + "123", + "123.45", + '["a", 100]', + '{"a": 100}', +] + + +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_wrapper_regtype(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + cur = conn.cursor() + cur.execute( + f"select pg_typeof(%{fmt_in.value})::regtype = %s::regtype", + (wrapper([]), wrapper.__name__.lower()), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump(conn, val, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = json.loads(val) + cur = conn.cursor() + cur.execute( + f"select %{fmt_in.value}::text = %s::{wrapper.__name__.lower()}::text", + (wrapper(obj), val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.crdb_skip("json array") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_array_dump(conn, val, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = json.loads(val) + cur = conn.cursor() + cur.execute( + f"select %{fmt_in.value}::text = array[%s::{wrapper.__name__.lower()}]::text", + ([wrapper(obj)], val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load(conn, val, jtype, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{jtype}", (val,)) + assert cur.fetchone()[0] == json.loads(val) + + +@pytest.mark.crdb_skip("json array") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_array(conn, val, jtype, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select array[%s::{jtype}]", (val,)) + assert cur.fetchone()[0] == [json.loads(val)] + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("val", samples) +@pytest.mark.parametrize("jtype", ["json", "jsonb"]) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_copy(conn, val, jtype, fmt_out): + cur = conn.cursor() + stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format( + val, sql.Identifier(jtype), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([jtype]) + (got,) = copy.read_row() + + assert got == json.loads(val) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur = conn.cursor() + + set_json_dumps(my_dumps) + try: + cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj),)) + assert cur.fetchone()[0] is True + finally: + set_json_dumps(json.dumps) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise_context(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur1 = conn.cursor() + cur2 = conn.cursor() + + set_json_dumps(my_dumps, cur2) + cur1.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),)) + assert cur1.fetchone()[0] is None + cur2.execute(f"select %{fmt_in.value}->>'baz'", (wrapper(obj),)) + assert cur2.fetchone()[0] == "qux" + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"]) +def test_dump_customise_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.json, wrapper) + obj = {"foo": "bar"} + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}->>'baz' = 'qux'", (wrapper(obj, my_dumps),)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("binary", [True, False]) +@pytest.mark.parametrize("pgtype", ["json", "jsonb"]) +def test_load_customise(conn, binary, pgtype): + cur = conn.cursor(binary=binary) + + set_json_loads(my_loads) + try: + cur.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + obj = cur.fetchone()[0] + assert obj["foo"] == "bar" + assert obj["answer"] == 42 + finally: + set_json_loads(json.loads) + + +@pytest.mark.parametrize("binary", [True, False]) +@pytest.mark.parametrize("pgtype", ["json", "jsonb"]) +def test_load_customise_context(conn, binary, pgtype): + cur1 = conn.cursor(binary=binary) + cur2 = conn.cursor(binary=binary) + + set_json_loads(my_loads, cur2) + cur1.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + got = cur1.fetchone()[0] + assert got["foo"] == "bar" + assert "answer" not in got + + cur2.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""") + got = cur2.fetchone()[0] + assert got["foo"] == "bar" + assert got["answer"] == 42 + + +def my_dumps(obj): + obj = deepcopy(obj) + obj["baz"] = "qux" + return json.dumps(obj) + + +def my_loads(data): + obj = json.loads(data) + obj["answer"] = 42 + return obj diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py new file mode 100644 index 0000000..2ab5152 --- /dev/null +++ b/tests/types/test_multirange.py @@ -0,0 +1,434 @@ +import pickle +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg import pq, sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg.types.range import Range +from psycopg.types import multirange +from psycopg.types.multirange import Multirange, MultirangeInfo +from psycopg.types.multirange import register_multirange + +from ..utils import eur +from .test_range import create_test_range + +pytestmark = [ + pytest.mark.pg(">= 14"), + pytest.mark.crdb_skip("range"), +] + + +class TestMultirangeObject: + def test_empty(self): + mr = Multirange[int]() + assert not mr + assert len(mr) == 0 + + mr = Multirange([]) + assert not mr + assert len(mr) == 0 + + def test_sequence(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + assert mr + assert len(mr) == 3 + assert mr[2] == Range(50, 60) + assert mr[-2] == Range(30, 40) + + def test_bad_type(self): + with pytest.raises(TypeError): + Multirange(Range(10, 20)) # type: ignore[arg-type] + + with pytest.raises(TypeError): + Multirange([10]) # type: ignore[arg-type] + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + + with pytest.raises(TypeError): + mr[0] = "foo" # type: ignore[call-overload] + + with pytest.raises(TypeError): + mr[0:1] = "foo" # type: ignore[assignment] + + with pytest.raises(TypeError): + mr[0:1] = ["foo"] # type: ignore[list-item] + + with pytest.raises(TypeError): + mr.insert(0, "foo") # type: ignore[arg-type] + + def test_setitem(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1] = Range(31, 41) + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + + def test_setitem_slice(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1:3] = [Range(31, 41), Range(51, 61)] + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)]) + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + with pytest.raises(TypeError, match="can only assign an iterable"): + mr[1:3] = Range(31, 41) # type: ignore[call-overload] + + mr[1:3] = [Range(31, 41)] + assert mr == Multirange([Range(10, 20), Range(31, 41)]) + + def test_delitem(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + del mr[1] + assert mr == Multirange([Range(10, 20), Range(50, 60)]) + + del mr[-2] + assert mr == Multirange([Range(50, 60)]) + + def test_insert(self): + mr = Multirange([Range(10, 20), Range(50, 60)]) + mr.insert(1, Range(31, 41)) + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + + def test_relations(self): + mr1 = Multirange([Range(10, 20), Range(30, 40)]) + mr2 = Multirange([Range(11, 20), Range(30, 40)]) + mr3 = Multirange([Range(9, 20), Range(30, 40)]) + assert mr1 <= mr1 + assert not mr1 < mr1 + assert mr1 >= mr1 + assert not mr1 > mr1 + assert mr1 < mr2 + assert mr1 <= mr2 + assert mr1 > mr3 + assert mr1 >= mr3 + assert mr1 != mr2 + assert not mr1 == mr2 + + def test_pickling(self): + r = Multirange([Range(0, 4)]) + assert pickle.loads(pickle.dumps(r)) == r + + def test_str(self): + mr = Multirange([Range(10, 20), Range(30, 40)]) + assert str(mr) == "{[10, 20), [30, 40)}" + + def test_repr(self): + mr = Multirange([Range(10, 20), Range(30, 40)]) + expected = "Multirange([Range(10, 20, '[)'), Range(30, 40, '[)')])" + assert repr(mr) == expected + + +tzinfo = dt.timezone(dt.timedelta(hours=2)) + +samples = [ + ("int4multirange", [Range(None, None, "()")]), + ("int4multirange", [Range(10, 20), Range(30, 40)]), + ("int8multirange", [Range(None, None, "()")]), + ("int8multirange", [Range(10, 20), Range(30, 40)]), + ( + "nummultirange", + [ + Range(None, Decimal(-100)), + Range(Decimal(100), Decimal("100.123")), + ], + ), + ( + "datemultirange", + [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))], + ), + ( + "tsmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999), + ) + ], + ), + ( + "tstzmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + Range( + dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + ], + ), +] + +mr_names = """ + int4multirange int8multirange nummultirange + datemultirange tsmultirange tstzmultirange""".split() + +mr_classes = """ + Int4Multirange Int8Multirange NumericMultirange + DateMultirange TimestampMultirange TimestamptzMultirange""".split() + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty(conn, pgtype, fmt_in): + mr = Multirange() # type: ignore[var-annotated] + cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in.value}", (mr,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + dumper = getattr(multirange, wrapper + "Dumper") + wrapper = getattr(multirange, wrapper) + mr = wrapper() + rec = conn.execute( + f""" + select '{{}}' = %(mr){fmt_in.value}, + %(mr){fmt_in.value}::text, + pg_typeof(%(mr){fmt_in.value})::oid + """, + {"mr": mr}, + ).fetchone() + assert rec[0] is True, rec[1] + assert rec[2] == dumper.oid + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize( + "fmt_in", + [ + PyFormat.AUTO, + PyFormat.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). + pytest.param( + PyFormat.BINARY, + marks=pytest.mark.xfail( + reason="can't dump array of untypes binary multirange without cast" + ), + ), + ], +) +def test_dump_builtin_array(conn, pgtype, fmt_in): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in.value}", + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.execute( + f""" + select array['{{}}'::{pgtype}, + '{{(,)}}'::{pgtype}] = %{fmt_in.value}::{pgtype}[] + """, + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(multirange, wrapper) + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.execute( + f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in.value}""", ([mr1, mr2],) + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in.value}", ranges + [mr]) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_empty(conn, pgtype, fmt_out): + mr = Multirange() # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone() + assert type(got) is Multirange + assert got == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_array(conn, pgtype, fmt_out): + mr1 = Multirange() # type: ignore[var-annotated] + mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]" + ).fetchone() + assert got == [mr1, mr2] + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_range(conn, pgtype, ranges, fmt_out): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select {pgtype}({phs})", ranges) + assert cur.fetchone()[0] == mr + + +@pytest.mark.parametrize( + "min, max, bounds", + [ + ("2000,1,1", "2001,1,1", "[)"), + ("2000,1,1", None, "[)"), + (None, "2001,1,1", "()"), + (None, None, "()"), + (None, None, "empty"), + ], +) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in(conn, min, max, bounds, format): + cur = conn.cursor() + cur.execute("create table copymr (id serial primary key, mr datemultirange)") + + if bounds != "empty": + min = dt.date(*map(int, min.split(","))) if min else None + max = dt.date(*map(int, max.split(","))) if max else None + r = Range[dt.date](min, max, bounds) + else: + r = Range(empty=True) + + mr = Multirange([r]) + try: + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.write_row([mr]) + except e.InternalError_: + if not min and not max and format == pq.Format.BINARY: + pytest.xfail("TODO: add annotation to dump multirange with no type info") + else: + raise + + rec = cur.execute("select mr from copymr order by id").fetchone() + if not r.isempty: + assert rec[0] == mr + else: + assert rec[0] == Multirange() + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_wrappers(conn, wrapper, format): + cur = conn.cursor() + cur.execute("create table copymr (id serial primary key, mr datemultirange)") + + mr = getattr(multirange, wrapper)() + + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_set_type(conn, pgtype, format): + cur = conn.cursor() + cur.execute(f"create table copymr (id serial primary key, mr {pgtype})") + + mr = Multirange() # type: ignore[var-annotated] + + with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy: + copy.set_types([pgtype]) + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.fixture(scope="session") +def testmr(svcconn): + create_test_range(svcconn) + + +fetch_cases = [ + ("testmultirange", "text"), + ("testschema.testmultirange", "float8"), + (sql.Identifier("testmultirange"), "text"), + (sql.Identifier("testschema", "testmultirange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testmr, name, subtype): + info = MultirangeInfo.fetch(conn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == conn.adapters.types[subtype].oid + + +def test_fetch_info_not_found(conn): + assert MultirangeInfo.fetch(conn, "nosuchrange") is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testmr, name, subtype): # noqa: F811 + info = await MultirangeInfo.fetch(aconn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == aconn.adapters.types[subtype].oid + + +@pytest.mark.asyncio +async def test_fetch_info_not_found_async(aconn): + assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None + + +def test_dump_custom_empty(conn, testmr): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + r = Multirange() # type: ignore[var-annotated] + cur = conn.execute("select '{}'::testmultirange = %s", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_custom_empty(conn, testmr, fmt_out): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select '{}'::testmultirange").fetchone() + assert isinstance(got, Multirange) + assert not got + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute(f'create type "{name}" as range (subtype = text)') + info = MultirangeInfo.fetch(conn, f'"{name}_multirange"') + register_multirange(info, conn) + obj = Multirange([Range("a", "z", "[]")]) + assert sql.Literal(obj).as_string(conn) == f"'{{[a,z]}}'::\"{name}_multirange\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + assert cur.fetchone()[0] == obj diff --git a/tests/types/test_net.py b/tests/types/test_net.py new file mode 100644 index 0000000..8739398 --- /dev/null +++ b/tests/types/test_net.py @@ -0,0 +1,135 @@ +import ipaddress + +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat + +crdb_skip_cidr = pytest.mark.crdb_skip("cidr") + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"]) +def test_address_dump(conn, fmt_in, val): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::inet", (ipaddress.ip_address(val), val)) + assert cur.fetchone()[0] is True + cur.execute( + f"select %{fmt_in.value} = array[null, %s]::inet[]", + ([None, ipaddress.ip_interface(val)], val), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/128"]) +def test_interface_dump(conn, fmt_in, val): + cur = conn.cursor() + rec = cur.execute( + f"select %(val){fmt_in.value} = %(repr)s::inet," + f" %(val){fmt_in.value}, %(repr)s::inet", + {"val": ipaddress.ip_interface(val), "repr": val}, + ).fetchone() + assert rec[0] is True, f"{rec[1]} != {rec[2]}" + cur.execute( + f"select %{fmt_in.value} = array[null, %s]::inet[]", + ([None, ipaddress.ip_interface(val)], val), + ) + assert cur.fetchone()[0] is True + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"]) +def test_network_dump(conn, fmt_in, val): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::cidr", (ipaddress.ip_network(val), val)) + assert cur.fetchone()[0] is True + cur.execute( + f"select %{fmt_in.value} = array[NULL, %s]::cidr[]", + ([None, ipaddress.ip_network(val)], val), + ) + assert cur.fetchone()[0] is True + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_network_mixed_size_array(conn, fmt_in): + val = [ + ipaddress.IPv4Network("192.168.0.1/32"), + ipaddress.IPv6Network("::1/128"), + ] + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}", (val,)) + got = cur.fetchone()[0] + assert val == got + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"]) +def test_inet_load_address(conn, fmt_out, val): + addr = ipaddress.ip_address(val.split("/", 1)[0]) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::inet", (val,)) + assert cur.fetchone()[0] == addr + + cur.execute("select array[null, %s::inet]", (val,)) + assert cur.fetchone()[0] == [None, addr] + + stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == addr + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"]) +def test_inet_load_network(conn, fmt_out, val): + pyval = ipaddress.ip_interface(val) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::inet", (val,)) + assert cur.fetchone()[0] == pyval + + cur.execute("select array[null, %s::inet]", (val,)) + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["inet"]) + (got,) = copy.read_row() + + assert got == pyval + + +@crdb_skip_cidr +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"]) +def test_cidr_load(conn, fmt_out, val): + pyval = ipaddress.ip_network(val) + cur = conn.cursor(binary=fmt_out) + + cur.execute("select %s::cidr", (val,)) + assert cur.fetchone()[0] == pyval + + cur.execute("select array[null, %s::cidr]", (val,)) + assert cur.fetchone()[0] == [None, pyval] + + stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["cidr"]) + (got,) = copy.read_row() + + assert got == pyval diff --git a/tests/types/test_none.py b/tests/types/test_none.py new file mode 100644 index 0000000..4c008fd --- /dev/null +++ b/tests/types/test_none.py @@ -0,0 +1,12 @@ +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat + + +def test_quote_none(conn): + + tx = Transformer() + assert tx.get_dumper(None, PyFormat.TEXT).quote(None) == b"NULL" + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None))) + assert cur.fetchone()[0] is None diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py new file mode 100644 index 0000000..a27bc84 --- /dev/null +++ b/tests/types/test_numeric.py @@ -0,0 +1,625 @@ +import enum +from decimal import Decimal +from math import isnan, isinf, exp + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg.adapt import Transformer, PyFormat +from psycopg.types.numeric import FloatLoader + +from ..fix_crdb import is_crdb + +# +# Tests with int +# + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::int"), + (1, "'1'::int"), + (-1, "'-1'::int"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), + (int(2**63 - 1), "'9223372036854775807'::bigint"), + (int(-(2**63)), "'-9223372036854775808'::bigint"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_int(conn, val, expr, fmt_in): + assert isinstance(val, int) + cur = conn.cursor() + cur.execute(f"select {expr} = %{fmt_in.value}", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, "'0'::smallint"), + (1, "'1'::smallint"), + (-1, "'-1'::smallint"), + (42, "'42'::smallint"), + (-42, "'-42'::smallint"), + (int(2**15 - 1), f"'{2 ** 15 - 1}'::smallint"), + (int(-(2**15)), f"'{-2 ** 15}'::smallint"), + (int(2**15), f"'{2 ** 15}'::integer"), + (int(-(2**15) - 1), f"'{-2 ** 15 - 1}'::integer"), + (int(2**31 - 1), f"'{2 ** 31 - 1}'::integer"), + (int(-(2**31)), f"'{-2 ** 31}'::integer"), + (int(2**31), f"'{2 ** 31}'::bigint"), + (int(-(2**31) - 1), f"'{-2 ** 31 - 1}'::bigint"), + (int(2**63 - 1), f"'{2 ** 63 - 1}'::bigint"), + (int(-(2**63)), f"'{-2 ** 63}'::bigint"), + (int(2**63), f"'{2 ** 63}'::numeric"), + (int(-(2**63) - 1), f"'{-2 ** 63 - 1}'::numeric"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_int_subtypes(conn, val, expr, fmt_in): + cur = conn.cursor() + cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in.value})", (val,)) + assert cur.fetchone()[0] is True + cur.execute( + f"select {expr} = %(v){fmt_in.value}, {expr}::text, %(v){fmt_in.value}::text", + {"v": val}, + ) + ok, want, got = cur.fetchone() + assert got == want + assert ok + + +class MyEnum(enum.IntEnum): + foo = 42 + + +class MyMixinEnum(enum.IntEnum): + foo = 42000000 + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("enum", [MyEnum, MyMixinEnum]) +def test_dump_enum(conn, fmt_in, enum): + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value}", (enum.foo,)) + (res,) = cur.fetchone() + assert res == enum.foo.value + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, b"0"), + (1, b"1"), + (-1, b" -1"), + (42, b"42"), + (-42, b" -42"), + (int(2**63 - 1), b"9223372036854775807"), + (int(-(2**63)), b" -9223372036854775808"), + (int(2**63), b"9223372036854775808"), + (int(-(2**63 + 1)), b" -9223372036854775809"), + (int(2**100), b"1267650600228229401496703205376"), + (int(-(2**100)), b" -1267650600228229401496703205376"), + ], +) +def test_quote_int(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + assert cur.fetchone() == (val, -val) + + +@pytest.mark.parametrize( + "val, pgtype, want", + [ + ("0", "integer", 0), + ("1", "integer", 1), + ("-1", "integer", -1), + ("0", "int2", 0), + ("0", "int4", 0), + ("0", "int8", 0), + ("0", "integer", 0), + ("0", "oid", 0), + # bounds + ("-32768", "smallint", -32768), + ("+32767", "smallint", 32767), + ("-2147483648", "integer", -2147483648), + ("+2147483647", "integer", 2147483647), + ("-9223372036854775808", "bigint", -9223372036854775808), + ("9223372036854775807", "bigint", 9223372036854775807), + ("4294967295", "oid", 4294967295), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_int(conn, val, pgtype, want, fmt_out): + if pgtype == "integer" and is_crdb(conn): + pgtype = "int4" # "integer" is "int8" on crdb + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{pgtype}", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid + result = cur.fetchone()[0] + assert result == want + assert type(result) is type(want) + + # arrays work too + cur.execute(f"select array[%s::{pgtype}]", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid + result = cur.fetchone()[0] + assert result == [want] + assert type(result[0]) is type(want) + + +# +# Tests with float +# + + +@pytest.mark.parametrize( + "val, expr", + [ + (0.0, "'0'"), + (1.0, "'1'"), + (-1.0, "'-1'"), + (float("nan"), "'NaN'"), + (float("inf"), "'Infinity'"), + (float("-inf"), "'-Infinity'"), + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_float(conn, val, expr, fmt_in): + assert isinstance(val, float) + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = {expr}::float8", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0.0, b"0.0"), + (1.0, b"1.0"), + (10000000000000000.0, b"1e+16"), + (1000000.1, b"1000000.1"), + (-100000.000001, b" -100000.000001"), + (-1.0, b" -1.0"), + (float("nan"), b"'NaN'::float8"), + (float("inf"), b"'Infinity'::float8"), + (float("-inf"), b"'-Infinity'::float8"), + ], +) +def test_quote_float(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + if isnan(val): + assert isnan(r[0]) and isnan(r[1]) + else: + if isinstance(r[0], Decimal): + r = tuple(map(float, r)) + + assert r == (val, -val) + + +@pytest.mark.parametrize( + "val, expr", + [ + (exp(1), "exp(1.0)"), + (-exp(1), "-exp(1.0)"), + (1e30, "'1e30'"), + (1e-30, "1e-30"), + (-1e30, "'-1e30'"), + (-1e-30, "-1e-30"), + ], +) +def test_dump_float_approx(conn, val, expr): + assert isinstance(val, float) + cur = conn.cursor() + cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)) + assert cur.fetchone()[0] is True + + cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, pgtype, want", + [ + ("0", "float4", 0.0), + ("0.0", "float4", 0.0), + ("42", "float4", 42.0), + ("-42", "float4", -42.0), + ("0.0", "float8", 0.0), + ("0.0", "real", 0.0), + ("0.0", "double precision", 0.0), + ("0.0", "float4", 0.0), + ("nan", "float4", float("nan")), + ("inf", "float4", float("inf")), + ("-inf", "float4", -float("inf")), + ("nan", "float8", float("nan")), + ("inf", "float8", float("inf")), + ("-inf", "float8", -float("inf")), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_float(conn, val, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select %s::{pgtype}", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].oid + result = cur.fetchone()[0] + + def check(result, want): + assert type(result) is type(want) + if isnan(want): + assert isnan(result) + elif isinf(want): + assert isinf(result) + assert (result < 0) is (want < 0) + else: + assert result == want + + check(result, want) + + cur.execute(f"select array[%s::{pgtype}]", (val,)) + assert cur.pgresult.fformat(0) == fmt_out + assert cur.pgresult.ftype(0) == conn.adapters.types[pgtype].array_oid + result = cur.fetchone()[0] + assert isinstance(result, list) + check(result[0], want) + + +@pytest.mark.parametrize( + "expr, pgtype, want", + [ + ("exp(1.0)", "float4", 2.71828), + ("-exp(1.0)", "float4", -2.71828), + ("exp(1.0)", "float8", 2.71828182845905), + ("-exp(1.0)", "float8", -2.71828182845905), + ("1.42e10", "float4", 1.42e10), + ("-1.42e10", "float4", -1.42e10), + ("1.42e40", "float8", 1.42e40), + ("-1.42e40", "float8", -1.42e40), + ], +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_float_approx(conn, expr, pgtype, want, fmt_out): + cur = conn.cursor(binary=fmt_out) + cur.execute("select %s::%s" % (expr, pgtype)) + assert cur.pgresult.fformat(0) == fmt_out + result = cur.fetchone()[0] + assert result == pytest.approx(want) + + +@pytest.mark.crdb_skip("copy") +def test_load_float_copy(conn): + cur = conn.cursor(binary=False) + with cur.copy("copy (select 3.14::float8, 'hi'::text) to stdout;") as copy: + copy.set_types(["float8", "text"]) + rec = copy.read_row() + + assert rec[0] == pytest.approx(3.14) + assert rec[1] == "hi" + + +# +# Tests with decimal +# + + +@pytest.mark.parametrize( + "val", + [ + "0", + "-0", + "0.0", + "0.000000000000000000001", + "-0.000000000000000000001", + "nan", + "snan", + ], +) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_roundtrip_numeric(conn, val, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + val = Decimal(val) + cur.execute(f"select %{fmt_in.value}", (val,)) + result = cur.fetchone()[0] + assert isinstance(result, Decimal) + if val.is_nan(): + assert result.is_nan() + else: + assert result == val + + +@pytest.mark.parametrize( + "val, expr", + [ + ("0", b"0"), + ("0.0", b"0.0"), + ("0.00000000000000001", b"1E-17"), + ("-0.00000000000000001", b" -1E-17"), + ("nan", b"'NaN'::numeric"), + ("snan", b"'NaN'::numeric"), + ], +) +def test_quote_numeric(conn, val, expr): + val = Decimal(val) + tx = Transformer() + assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + + if val.is_nan(): + assert isnan(r[0]) and isnan(r[1]) + else: + assert r == (val, -val) + + +@pytest.mark.crdb_skip("binary decimal") +@pytest.mark.parametrize( + "expr", + ["NaN", "1", "1.0", "-1", "0.0", "0.01", "11", "1.1", "1.01", "0", "0.00"] + + [ + "0.0000000", + "0.00001", + "1.00001", + "-1.00000000000000", + "-2.00000000000000", + "1000000000.12345", + "100.123456790000000000000000", + "1.0e-1000", + "1e1000", + "0.000000000000000000000000001", + "1.0000000000000000000000001", + "1000000000000000000000000.001", + "1000000000000000000000000000.001", + "9999999999999999999999999999.9", + ], +) +def test_dump_numeric_binary(conn, expr): + cur = conn.cursor() + val = Decimal(expr) + cur.execute("select %b::text, %s::decimal::text", [val, expr]) + want, got = cur.fetchone() + assert got == want + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt_in", + [ + f + if f != PyFormat.BINARY + else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal")) + for f in PyFormat + ], +) +def test_dump_numeric_exhaustive(conn, fmt_in): + cur = conn.cursor() + + funcs = [ + (lambda i: "1" + "0" * i), + (lambda i: "1" + "0" * i + "." + "0" * i), + (lambda i: "-1" + "0" * i), + (lambda i: "0." + "0" * i + "1"), + (lambda i: "-0." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "10"), + (lambda i: "1" + "0" * i + ".001"), + (lambda i: "9" + "9" * i), + (lambda i: "9" + "." + "9" * i), + (lambda i: "9" + "9" * i + ".9"), + (lambda i: "9" + "9" * i + "." + "9" * i), + (lambda i: "1.1e%s" % i), + (lambda i: "1.1e-%s" % i), + ] + + for i in range(100): + for f in funcs: + expr = f(i) + val = Decimal(expr) + cur.execute(f"select %{fmt_in.value}::text, %s::decimal::text", [val, expr]) + got, want = cur.fetchone() + assert got == want + + +@pytest.mark.pg(">= 14") +@pytest.mark.parametrize( + "val, expr", + [ + ("inf", "Infinity"), + ("-inf", "-Infinity"), + ], +) +def test_dump_numeric_binary_inf(conn, val, expr): + cur = conn.cursor() + val = Decimal(val) + cur.execute("select %b", [val]) + + +@pytest.mark.parametrize( + "expr", + ["nan", "0", "1", "-1", "0.0", "0.01"] + + [ + "0.0000000", + "-1.00000000000000", + "-2.00000000000000", + "1000000000.12345", + "100.123456790000000000000000", + "1.0e-1000", + "1e1000", + "0.000000000000000000000000001", + "1.0000000000000000000000001", + "1000000000000000000000000.001", + "1000000000000000000000000000.001", + "9999999999999999999999999999.9", + ], +) +def test_load_numeric_binary(conn, expr): + cur = conn.cursor(binary=1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(expr) + if val.is_nan(): + assert res.is_nan() + else: + assert res == val + if "e" not in expr: + assert str(res) == str(val) + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_numeric_exhaustive(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + + funcs = [ + (lambda i: "1" + "0" * i), + (lambda i: "1" + "0" * i + "." + "0" * i), + (lambda i: "-1" + "0" * i), + (lambda i: "0." + "0" * i + "1"), + (lambda i: "-0." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "1"), + (lambda i: "1." + "0" * i + "10"), + (lambda i: "1" + "0" * i + ".001"), + (lambda i: "9" + "9" * i), + (lambda i: "9" + "." + "9" * i), + (lambda i: "9" + "9" * i + ".9"), + (lambda i: "9" + "9" * i + "." + "9" * i), + ] + + for i in range(100): + for f in funcs: + snum = f(i) + want = Decimal(snum) + got = cur.execute(f"select '{snum}'::decimal").fetchone()[0] + assert want == got + assert str(want) == str(got) + + +@pytest.mark.pg(">= 14") +@pytest.mark.parametrize( + "val, expr", + [ + ("inf", "Infinity"), + ("-inf", "-Infinity"), + ], +) +def test_load_numeric_binary_inf(conn, val, expr): + cur = conn.cursor(binary=1) + res = cur.execute(f"select '{expr}'::numeric").fetchone()[0] + val = Decimal(val) + assert res == val + + +@pytest.mark.parametrize( + "val", + [ + "0", + "0.0", + "0.000000000000000000001", + "-0.000000000000000000001", + "nan", + ], +) +def test_numeric_as_float(conn, val): + cur = conn.cursor() + cur.adapters.register_loader("numeric", FloatLoader) + + val = Decimal(val) + cur.execute("select %s as val", (val,)) + result = cur.fetchone()[0] + assert isinstance(result, float) + if val.is_nan(): + assert isnan(result) + else: + assert result == pytest.approx(float(val)) + + # the customization works with arrays too + cur.execute("select %s as arr", ([val],)) + result = cur.fetchone()[0] + assert isinstance(result, list) + assert isinstance(result[0], float) + if val.is_nan(): + assert isnan(result[0]) + else: + assert result[0] == pytest.approx(float(val)) + + +# +# Mixed tests +# + + +@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"]) +def test_minus_minus(conn, pgtype): + cur = conn.cursor() + cast = f"::{pgtype}" if pgtype is not None else "" + cur.execute(f"select -%s{cast}", [-1]) + result = cur.fetchone()[0] + assert result == 1 + + +@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"]) +def test_minus_minus_quote(conn, pgtype): + cur = conn.cursor() + cast = f"::{pgtype}" if pgtype is not None else "" + cur.execute(sql.SQL("select -{}{}").format(sql.Literal(-1), sql.SQL(cast))) + result = cur.fetchone()[0] + assert result == 1 + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.numeric, wrapper) + obj = wrapper(1) + cur = conn.execute( + f"select %(obj){fmt_in.value} = 1, %(obj){fmt_in.value}", {"obj": obj} + ) + rec = cur.fetchone() + assert rec[0], rec[1] + + +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +def test_dump_wrapper_oid(wrapper): + wrapper = getattr(psycopg.types.numeric, wrapper) + base = wrapper.__mro__[1] + assert base in (int, float) + n = base(3.14) + assert str(wrapper(n)) == str(n) + assert repr(wrapper(n)) == f"{wrapper.__name__}({n})" + + +@pytest.mark.crdb("skip", reason="all types returned as bigint? TODOCRDB") +@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Oid Float4 Float8".split()) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_repr_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(psycopg.types.numeric, wrapper) + cur = conn.execute(f"select pg_typeof(%{fmt_in.value})::oid", [wrapper(0)]) + oid = cur.fetchone()[0] + assert oid == psycopg.postgres.types[wrapper.__name__.lower()].oid + + +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "typename", + "integer int2 int4 int8 float4 float8 numeric".split() + ["double precision"], +) +def test_oid_lookup(conn, typename, fmt_out): + dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out) + assert dumper.oid == conn.adapters.types[typename].oid + assert dumper.format == fmt_out diff --git a/tests/types/test_range.py b/tests/types/test_range.py new file mode 100644 index 0000000..1efd398 --- /dev/null +++ b/tests/types/test_range.py @@ -0,0 +1,677 @@ +import pickle +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg import pq, sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg.types import range as range_module +from psycopg.types.range import Range, RangeInfo, register_range + +from ..utils import eur +from ..fix_crdb import is_crdb, crdb_skip_message + +pytestmark = pytest.mark.crdb_skip("range") + +type2sub = { + "int4range": "int4", + "int8range": "int8", + "numrange": "numeric", + "daterange": "date", + "tsrange": "timestamp", + "tstzrange": "timestamptz", +} + +tzinfo = dt.timezone(dt.timedelta(hours=2)) + +samples = [ + ("int4range", None, None, "()"), + ("int4range", 10, 20, "[]"), + ("int4range", -(2**31), (2**31) - 1, "[)"), + ("int8range", None, None, "()"), + ("int8range", 10, 20, "[)"), + ("int8range", -(2**63), (2**63) - 1, "[)"), + ("numrange", Decimal(-100), Decimal("100.123"), "(]"), + ("numrange", Decimal(100), None, "()"), + ("numrange", None, Decimal(100), "()"), + ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"), + ( + "tsrange", + dt.datetime(2000, 1, 1, 00, 00), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999), + "[]", + ), + ( + "tstzrange", + dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + "()", + ), +] + +range_names = """ + int4range int8range numrange daterange tsrange tstzrange + """.split() + +range_classes = """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split() + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty(conn, pgtype, fmt_in): + r = Range(empty=True) # type: ignore[var-annotated] + cur = conn.execute(f"select 'empty'::{pgtype} = %{fmt_in.value}", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r = wrapper(empty=True) + cur = conn.execute(f"select 'empty' = %{fmt_in.value}", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize( + "fmt_in", + [ + PyFormat.AUTO, + PyFormat.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). + pytest.param( + PyFormat.BINARY, + marks=pytest.mark.xfail( + reason="can't dump an array of untypes binary range without cast" + ), + ), + ], +) +def test_dump_builtin_array(conn, pgtype, fmt_in): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.execute( + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %{fmt_in.value}", + ([r1, r2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.execute( + f"select array['empty'::{pgtype}, '(,)'::{pgtype}] " + f"= %{fmt_in.value}::{pgtype}[]", + ([r1, r2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(range_module, wrapper) + r1 = wrapper(empty=True) + r2 = wrapper(bounds="()") + cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in.value}""", ([r1, r2],)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype, min, max, bounds", samples) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_range(conn, pgtype, min, max, bounds, fmt_in): + r = Range(min, max, bounds) # type: ignore[var-annotated] + sub = type2sub[pgtype] + cur = conn.execute( + f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %{fmt_in.value}", + (min, max, bounds, r), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_empty(conn, pgtype, fmt_out): + r = Range(empty=True) # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone() + assert type(got) is Range + assert got == r + assert not got + assert got.isempty + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_inf(conn, pgtype, fmt_out): + r = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '(,)'::{pgtype}").fetchone() + assert type(got) is Range + assert got == r + assert got + assert not got.isempty + assert got.lower_inf + assert got.upper_inf + + +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_array(conn, pgtype, fmt_out): + r1 = Range(empty=True) # type: ignore[var-annotated] + r2 = Range(bounds="()") # type: ignore[var-annotated] + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone() + assert got == [r1, r2] + + +@pytest.mark.parametrize("pgtype, min, max, bounds", samples) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out): + r = Range(min, max, bounds) # type: ignore[var-annotated] + sub = type2sub[pgtype] + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)) + # normalise discrete ranges + if r.upper_inc and isinstance(r.upper, int): + bounds = "[)" if r.lower_inc else "()" + r = type(r)(r.lower, r.upper + 1, bounds) + assert cur.fetchone()[0] == r + + +@pytest.mark.parametrize( + "min, max, bounds", + [ + ("2000,1,1", "2001,1,1", "[)"), + ("2000,1,1", None, "[)"), + (None, "2001,1,1", "()"), + (None, None, "()"), + (None, None, "empty"), + ], +) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in(conn, min, max, bounds, format): + cur = conn.cursor() + cur.execute("create table copyrange (id serial primary key, r daterange)") + + if bounds != "empty": + min = dt.date(*map(int, min.split(","))) if min else None + max = dt.date(*map(int, max.split(","))) if max else None + r = Range[dt.date](min, max, bounds) + else: + r = Range(empty=True) + + try: + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.write_row([r]) + except e.ProtocolViolation: + if not min and not max and format == pq.Format.BINARY: + pytest.xfail("TODO: add annotation to dump ranges with no type info") + else: + raise + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.mark.parametrize("bounds", "() empty".split()) +@pytest.mark.parametrize("wrapper", range_classes) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_wrappers(conn, bounds, wrapper, format): + cur = conn.cursor() + cur.execute("create table copyrange (id serial primary key, r daterange)") + + cls = getattr(range_module, wrapper) + r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds) + + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.write_row([r]) + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.mark.parametrize("bounds", "() empty".split()) +@pytest.mark.parametrize("pgtype", range_names) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_set_type(conn, bounds, pgtype, format): + cur = conn.cursor() + cur.execute(f"create table copyrange (id serial primary key, r {pgtype})") + + if bounds == "empty": + r = Range(empty=True) # type: ignore[var-annotated] + else: + r = Range(None, None, bounds) + + with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy: + copy.set_types([pgtype]) + copy.write_row([r]) + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + +@pytest.fixture(scope="session") +def testrange(svcconn): + create_test_range(svcconn) + + +def create_test_range(conn): + if is_crdb(conn): + pytest.skip(crdb_skip_message("range")) + + conn.execute( + """ + create schema if not exists testschema; + + drop type if exists testrange cascade; + drop type if exists testschema.testrange cascade; + + create type testrange as range (subtype = text, collation = "C"); + create type testschema.testrange as range (subtype = float8); + """ + ) + + +fetch_cases = [ + ("testrange", "text"), + ("testschema.testrange", "float8"), + (sql.Identifier("testrange"), "text"), + (sql.Identifier("testschema", "testrange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testrange, name, subtype): + info = RangeInfo.fetch(conn, name) + assert info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == conn.adapters.types[subtype].oid + + +def test_fetch_info_not_found(conn): + assert RangeInfo.fetch(conn, "nosuchrange") is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testrange, name, subtype): + info = await RangeInfo.fetch(aconn, name) + assert info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == aconn.adapters.types[subtype].oid + + +@pytest.mark.asyncio +async def test_fetch_info_not_found_async(aconn): + assert await RangeInfo.fetch(aconn, "nosuchrange") is None + + +def test_dump_custom_empty(conn, testrange): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + + r = Range[str](empty=True) + cur = conn.execute("select 'empty'::testrange = %s", (r,)) + assert cur.fetchone()[0] is True + + +def test_dump_quoting(conn, testrange): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + cur = conn.cursor() + for i in range(1, 254): + cur.execute( + """ + select ascii(lower(%(r)s)) = %(low)s + and ascii(upper(%(r)s)) = %(up)s + """, + {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1}, + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_custom_empty(conn, testrange, fmt_out): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select 'empty'::testrange").fetchone() + assert isinstance(got, Range) + assert got.isempty + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_quoting(conn, testrange, fmt_out): + info = RangeInfo.fetch(conn, "testrange") + register_range(info, conn) + cur = conn.cursor(binary=fmt_out) + for i in range(1, 254): + cur.execute( + "select testrange(chr(%(low)s::int), chr(%(up)s::int))", + {"low": i, "up": i + 1}, + ) + got: Range[str] = cur.fetchone()[0] + assert isinstance(got, Range) + assert got.lower and ord(got.lower) == i + assert got.upper and ord(got.upper) == i + 1 + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_mixed_array_types(conn, fmt_out): + conn.execute("create table testmix (a daterange[], b tstzrange[])") + r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)") + r2 = Range( + dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), + dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc), + "[)", + ) + conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]]) + got = conn.execute("select * from testmix").fetchone() + assert got == ([r1], [r2]) + + +class TestRangeObject: + def test_noparam(self): + r = Range() # type: ignore[var-annotated] + + assert not r.isempty + assert r.lower is None + assert r.upper is None + assert r.lower_inf + assert r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_empty(self): + r = Range(empty=True) # type: ignore[var-annotated] + + assert r.isempty + assert r.lower is None + assert r.upper is None + assert not r.lower_inf + assert not r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_nobounds(self): + r = Range(10, 20) + assert r.lower == 10 + assert r.upper == 20 + assert not r.isempty + assert not r.lower_inf + assert not r.upper_inf + assert r.lower_inc + assert not r.upper_inc + + def test_bounds(self): + for bounds, lower_inc, upper_inc in [ + ("[)", True, False), + ("(]", False, True), + ("()", False, False), + ("[]", True, True), + ]: + r = Range(10, 20, bounds) + assert r.bounds == bounds + assert r.lower == 10 + assert r.upper == 20 + assert not r.isempty + assert not r.lower_inf + assert not r.upper_inf + assert r.lower_inc == lower_inc + assert r.upper_inc == upper_inc + + def test_keywords(self): + r = Range(upper=20) + r.lower is None + r.upper == 20 + assert not r.isempty + assert r.lower_inf + assert not r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + r = Range(lower=10, bounds="(]") + r.lower == 10 + r.upper is None + assert not r.isempty + assert not r.lower_inf + assert r.upper_inf + assert not r.lower_inc + assert not r.upper_inc + + def test_bad_bounds(self): + with pytest.raises(ValueError): + Range(bounds="(") + with pytest.raises(ValueError): + Range(bounds="[}") + + def test_in(self): + r = Range[int](empty=True) + assert 10 not in r + assert "x" not in r # type: ignore[operator] + + r = Range() + assert 10 in r + + r = Range(lower=10, bounds="[)") + assert 9 not in r + assert 10 in r + assert 11 in r + + r = Range(lower=10, bounds="()") + assert 9 not in r + assert 10 not in r + assert 11 in r + + r = Range(upper=20, bounds="()") + assert 19 in r + assert 20 not in r + assert 21 not in r + + r = Range(upper=20, bounds="(]") + assert 19 in r + assert 20 in r + assert 21 not in r + + r = Range(10, 20) + assert 9 not in r + assert 10 in r + assert 11 in r + assert 19 in r + assert 20 not in r + assert 21 not in r + + r = Range(10, 20, "(]") + assert 9 not in r + assert 10 not in r + assert 11 in r + assert 19 in r + assert 20 in r + assert 21 not in r + + r = Range(20, 10) + assert 9 not in r + assert 10 not in r + assert 11 not in r + assert 19 not in r + assert 20 not in r + assert 21 not in r + + def test_nonzero(self): + assert Range() + assert Range(10, 20) + assert not Range(empty=True) + + def test_eq_hash(self): + def assert_equal(r1, r2): + assert r1 == r2 + assert hash(r1) == hash(r2) + + assert_equal(Range(empty=True), Range(empty=True)) + assert_equal(Range(), Range()) + assert_equal(Range(10, None), Range(10, None)) + assert_equal(Range(10, 20), Range(10, 20)) + assert_equal(Range(10, 20), Range(10, 20, "[)")) + assert_equal(Range(10, 20, "[]"), Range(10, 20, "[]")) + + def assert_not_equal(r1, r2): + assert r1 != r2 + assert hash(r1) != hash(r2) + + assert_not_equal(Range(10, 20), Range(10, 21)) + assert_not_equal(Range(10, 20), Range(11, 20)) + assert_not_equal(Range(10, 20, "[)"), Range(10, 20, "[]")) + + def test_eq_wrong_type(self): + assert Range(10, 20) != () + + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. + + def test_lt_ordering(self): + assert Range(empty=True) < Range(0, 4) + assert not Range(1, 2) < Range(0, 4) + assert Range(0, 4) < Range(1, 2) + assert not Range(1, 2) < Range() + assert Range() < Range(1, 2) + assert not Range(1) < Range(upper=1) + assert not Range() < Range() + assert not Range(empty=True) < Range(empty=True) + assert not Range(1, 2) < Range(1, 2) + with pytest.raises(TypeError): + assert 1 < Range(1, 2) + with pytest.raises(TypeError): + assert not Range(1, 2) < 1 + + def test_gt_ordering(self): + assert not Range(empty=True) > Range(0, 4) + assert Range(1, 2) > Range(0, 4) + assert not Range(0, 4) > Range(1, 2) + assert Range(1, 2) > Range() + assert not Range() > Range(1, 2) + assert Range(1) > Range(upper=1) + assert not Range() > Range() + assert not Range(empty=True) > Range(empty=True) + assert not Range(1, 2) > Range(1, 2) + with pytest.raises(TypeError): + assert not 1 > Range(1, 2) + with pytest.raises(TypeError): + assert Range(1, 2) > 1 + + def test_le_ordering(self): + assert Range(empty=True) <= Range(0, 4) + assert not Range(1, 2) <= Range(0, 4) + assert Range(0, 4) <= Range(1, 2) + assert not Range(1, 2) <= Range() + assert Range() <= Range(1, 2) + assert not Range(1) <= Range(upper=1) + assert Range() <= Range() + assert Range(empty=True) <= Range(empty=True) + assert Range(1, 2) <= Range(1, 2) + with pytest.raises(TypeError): + assert 1 <= Range(1, 2) + with pytest.raises(TypeError): + assert not Range(1, 2) <= 1 + + def test_ge_ordering(self): + assert not Range(empty=True) >= Range(0, 4) + assert Range(1, 2) >= Range(0, 4) + assert not Range(0, 4) >= Range(1, 2) + assert Range(1, 2) >= Range() + assert not Range() >= Range(1, 2) + assert Range(1) >= Range(upper=1) + assert Range() >= Range() + assert Range(empty=True) >= Range(empty=True) + assert Range(1, 2) >= Range(1, 2) + with pytest.raises(TypeError): + assert not 1 >= Range(1, 2) + with pytest.raises(TypeError): + (Range(1, 2) >= 1) + + def test_pickling(self): + r = Range(0, 4) + assert pickle.loads(pickle.dumps(r)) == r + + def test_str(self): + """ + Range types should have a short and readable ``str`` implementation. + """ + expected = [ + "(0, 4)", + "[0, 4]", + "(0, 4]", + "[0, 4)", + "empty", + ] + results = [] + + for bounds in ("()", "[]", "(]", "[)"): + r = Range(0, 4, bounds=bounds) + results.append(str(r)) + + r = Range(empty=True) + results.append(str(r)) + assert results == expected + + def test_str_datetime(self): + """ + Date-Time ranges should return a human-readable string as well on + string conversion. + """ + tz = dt.timezone(dt.timedelta(hours=-5)) + r = Range( + dt.datetime(2010, 1, 1, tzinfo=tz), + dt.datetime(2011, 1, 1, tzinfo=tz), + ) + expected = "[2010-01-01 00:00:00-05:00, 2011-01-01 00:00:00-05:00)" + result = str(r) + assert result == expected + + def test_exclude_inf_bounds(self): + r = Range(None, 10, "[]") + assert r.lower is None + assert not r.lower_inc + assert r.bounds == "(]" + + r = Range(10, None, "[]") + assert r.upper is None + assert not r.upper_inc + assert r.bounds == "[)" + + r = Range(None, None, "[]") + assert r.lower is None + assert not r.lower_inc + assert r.upper is None + assert not r.upper_inc + assert r.bounds == "()" + + +def test_no_info_error(conn): + with pytest.raises(TypeError, match="range"): + register_range(None, conn) # type: ignore[arg-type] + + +@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"]) +def test_literal_invalid_name(conn, name): + conn.execute("set client_encoding to utf8") + conn.execute(f'create type "{name}" as range (subtype = text)') + info = RangeInfo.fetch(conn, f'"{name}"') + register_range(info, conn) + obj = Range("a", "z", "[]") + assert sql.Literal(obj).as_string(conn) == f"'[a,z]'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format(obj)) + assert cur.fetchone()[0] == obj diff --git a/tests/types/test_shapely.py b/tests/types/test_shapely.py new file mode 100644 index 0000000..0f7007e --- /dev/null +++ b/tests/types/test_shapely.py @@ -0,0 +1,152 @@ +import pytest + +import psycopg +from psycopg.pq import Format +from psycopg.types import TypeInfo +from psycopg.adapt import PyFormat + +pytest.importorskip("shapely") + +from shapely.geometry import Point, Polygon, MultiPolygon # noqa: E402 +from psycopg.types.shapely import register_shapely # noqa: E402 + +pytestmark = [ + pytest.mark.postgis, + pytest.mark.crdb("skip"), +] + +# real example, with CRS and "holes" +MULTIPOLYGON_GEOJSON = """ +{ + "type":"MultiPolygon", + "crs":{ + "type":"name", + "properties":{ + "name":"EPSG:3857" + } + }, + "coordinates":[ + [ + [ + [89574.61111389, 6894228.638802719], + [89576.815239808, 6894208.60747024], + [89576.904295401, 6894207.820852726], + [89577.99522641, 6894208.022080451], + [89577.961830563, 6894209.229446936], + [89589.227363031, 6894210.601454523], + [89594.615226386, 6894161.849595264], + [89600.314784314, 6894111.37846976], + [89651.187791607, 6894116.774968589], + [89648.49385993, 6894140.226914071], + [89642.92788539, 6894193.423936413], + [89639.721884055, 6894224.08372821], + [89589.283022777, 6894218.431048969], + [89588.192091767, 6894230.248628867], + [89574.61111389, 6894228.638802719] + ], + [ + [89610.344670435, 6894182.466199101], + [89625.985058891, 6894184.258949757], + [89629.547282597, 6894153.270030369], + [89613.918026089, 6894151.458993318], + [89610.344670435, 6894182.466199101] + ] + ] + ] +}""" + +SAMPLE_POINT_GEOJSON = '{"type":"Point","coordinates":[1.2, 3.4]}' + + +@pytest.fixture +def shapely_conn(conn, svcconn): + try: + with svcconn.transaction(): + svcconn.execute("create extension if not exists postgis") + except psycopg.Error as e: + pytest.skip(f"can't create extension postgis: {e}") + + info = TypeInfo.fetch(conn, "geometry") + assert info + register_shapely(info, conn) + return conn + + +def test_no_adapter(conn): + point = Point(1.2, 3.4) + with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"): + conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0] + + +def test_no_info_error(conn): + from psycopg.types.shapely import register_shapely + + with pytest.raises(TypeError, match="postgis.*extension"): + register_shapely(None, conn) # type: ignore[arg-type] + + +def test_with_adapter(shapely_conn): + SAMPLE_POINT = Point(1.2, 3.4) + SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)]) + + assert ( + shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0] + == "geometry" + ) + + assert ( + shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0] + == "geometry" + ) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", Format) +def test_write_read_shape(shapely_conn, fmt_in, fmt_out): + SAMPLE_POINT = Point(1.2, 3.4) + SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)]) + + with shapely_conn.cursor(binary=fmt_out) as cur: + cur.execute( + """ + create table sample_geoms( + id INTEGER PRIMARY KEY, + geom geometry + ) + """ + ) + cur.execute( + f"insert into sample_geoms(id, geom) VALUES(1, %{fmt_in})", + (SAMPLE_POINT,), + ) + cur.execute( + f"insert into sample_geoms(id, geom) VALUES(2, %{fmt_in})", + (SAMPLE_POLYGON,), + ) + + cur.execute("select geom from sample_geoms where id=1") + result = cur.fetchone()[0] + assert result == SAMPLE_POINT + + cur.execute("select geom from sample_geoms where id=2") + result = cur.fetchone()[0] + assert result == SAMPLE_POLYGON + + +@pytest.mark.parametrize("fmt_out", Format) +def test_match_geojson(shapely_conn, fmt_out): + SAMPLE_POINT = Point(1.2, 3.4) + with shapely_conn.cursor(binary=fmt_out) as cur: + cur.execute( + """ + select ST_GeomFromGeoJSON(%s) + """, + (SAMPLE_POINT_GEOJSON,), + ) + result = cur.fetchone()[0] + # clone the coordinates to have a list instead of a shapely wrapper + assert result.coords[:] == SAMPLE_POINT.coords[:] + # + cur.execute("select ST_GeomFromGeoJSON(%s)", (MULTIPOLYGON_GEOJSON,)) + result = cur.fetchone()[0] + assert isinstance(result, MultiPolygon) diff --git a/tests/types/test_string.py b/tests/types/test_string.py new file mode 100644 index 0000000..d23e5e0 --- /dev/null +++ b/tests/types/test_string.py @@ -0,0 +1,307 @@ +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.adapt import PyFormat +from psycopg import Binary + +from ..utils import eur +from ..fix_crdb import crdb_encoding, crdb_scs_off + +# +# tests with text +# + + +def crdb_bpchar(*args): + return pytest.param(*args, marks=pytest.mark.crdb("skip", reason="bpchar")) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_1char(conn, fmt_in): + cur = conn.cursor() + for i in range(1, 256): + cur.execute(f"select %{fmt_in.value} = chr(%s)", (chr(i), i)) + assert cur.fetchone()[0] is True, chr(i) + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_quote_1char(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") + + cur = conn.cursor() + query = sql.SQL("select {ch} = chr(%s)") + for i in range(1, 256): + if chr(i) == "%": + continue + cur.execute(query.format(ch=sql.Literal(chr(i))), (i,)) + assert cur.fetchone()[0] is True, chr(i) + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +@pytest.mark.crdb("skip", reason="can deal with 0 strings") +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_zero(conn, fmt_in): + cur = conn.cursor() + s = "foo\x00bar" + with pytest.raises(psycopg.DataError): + cur.execute(f"select %{fmt_in.value}::text", (s,)) + + +def test_quote_zero(conn): + cur = conn.cursor() + s = "foo\x00bar" + with pytest.raises(psycopg.DataError): + cur.execute(sql.SQL("select {}").format(sql.Literal(s))) + + +# the only way to make this pass is to reduce %% -> % every time +# not only when there are query arguments +# see https://github.com/psycopg/psycopg2/issues/825 +@pytest.mark.xfail +def test_quote_percent(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%"))) + assert cur.fetchone()[0] == "%" + + cur.execute( + sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")), + (ord("%"),), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "typename", ["text", "varchar", "name", crdb_bpchar("bpchar"), '"char"'] +) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_1char(conn, typename, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(1, 256): + if typename == '"char"' and i > 127: + # for char > 128 the client receives only 194 or 195. + continue + + cur.execute(f"select chr(%s)::{typename}", (i,)) + res = cur.fetchone()[0] + assert res == chr(i) + + assert cur.pgresult.fformat(0) == fmt_out + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize( + "encoding", ["utf8", crdb_encoding("latin9"), crdb_encoding("sql_ascii")] +) +def test_dump_enc(conn, fmt_in, encoding): + cur = conn.cursor() + + conn.execute(f"set client_encoding to {encoding}") + (res,) = cur.execute(f"select ascii(%{fmt_in.value})", (eur,)).fetchone() + assert res == ord(eur) + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_badenc(conn, fmt_in): + cur = conn.cursor() + + conn.execute("set client_encoding to latin1") + with pytest.raises(UnicodeEncodeError): + cur.execute(f"select %{fmt_in.value}::bytea", (eur,)) + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_utf8_badenc(conn, fmt_in): + cur = conn.cursor() + + conn.execute("set client_encoding to utf8") + with pytest.raises(UnicodeEncodeError): + cur.execute(f"select %{fmt_in.value}", ("\uddf8",)) + + +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_enum(conn, fmt_in): + from enum import Enum + + class MyEnum(str, Enum): + foo = "foo" + bar = "bar" + + cur = conn.cursor() + cur.execute("create type myenum as enum ('foo', 'bar')") + cur.execute("create table with_enum (e myenum)") + cur.execute(f"insert into with_enum (e) values (%{fmt_in.value})", (MyEnum.foo,)) + (res,) = cur.execute("select e from with_enum").fetchone() + assert res == "foo" + + +@pytest.mark.crdb("skip") +@pytest.mark.parametrize("fmt_in", [PyFormat.AUTO, PyFormat.TEXT]) +def test_dump_text_oid(conn, fmt_in): + conn.autocommit = True + + with pytest.raises(e.IndeterminateDatatype): + conn.execute(f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"]) + conn.adapters.register_dumper(str, psycopg.types.string.StrDumper) + cur = conn.execute( + f"select concat(%{fmt_in.value}, %{fmt_in.value})", ["foo", "bar"] + ) + assert cur.fetchone()[0] == "foobar" + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_enc(conn, typename, encoding, fmt_out): + cur = conn.cursor(binary=fmt_out) + + conn.execute(f"set client_encoding to {encoding}") + (res,) = cur.execute(f"select chr(%s)::{typename}", [ord(eur)]).fetchone() + assert res == eur + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + + assert res == eur + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_badenc(conn, typename, fmt_out): + conn.autocommit = True + cur = conn.cursor(binary=fmt_out) + + conn.execute("set client_encoding to latin1") + with pytest.raises(psycopg.DataError): + cur.execute(f"select chr(%s)::{typename}", [ord(eur)]) + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + with pytest.raises(psycopg.DataError): + copy.read_row() + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) +def test_load_ascii(conn, typename, fmt_out): + cur = conn.cursor(binary=fmt_out) + + conn.execute("set client_encoding to sql_ascii") + cur.execute(f"select chr(%s)::{typename}", [ord(eur)]) + assert cur.fetchone()[0] == eur.encode() + + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( + ord(eur), sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types([typename]) + (res,) = copy.read_row() + + assert res == eur.encode() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name", crdb_bpchar("bpchar")]) +def test_text_array(conn, typename, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + a = list(map(chr, range(1, 256))) + [eur] + + (res,) = cur.execute(f"select %{fmt_in.value}::{typename}[]", (a,)).fetchone() + assert res == a + + +@pytest.mark.crdb_skip("encoding") +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_text_array_ascii(conn, fmt_in, fmt_out): + conn.execute("set client_encoding to sql_ascii") + cur = conn.cursor(binary=fmt_out) + a = list(map(chr, range(1, 256))) + [eur] + exp = [s.encode() for s in a] + (res,) = cur.execute(f"select %{fmt_in.value}::text[]", (a,)).fetchone() + assert res == exp + + +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("typename", ["text", "varchar", "name"]) +def test_oid_lookup(conn, typename, fmt_out): + dumper = conn.adapters.get_dumper_by_oid(conn.adapters.types[typename].oid, fmt_out) + assert dumper.oid == conn.adapters.types[typename].oid + assert dumper.format == fmt_out + + +# +# tests with bytea +# + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary]) +def test_dump_1byte(conn, fmt_in, pytype): + cur = conn.cursor() + for i in range(0, 256): + obj = pytype(bytes([i])) + cur.execute(f"select %{fmt_in.value} = set_byte('x', 0, %s)", (obj, i)) + assert cur.fetchone()[0] is True, i + + cur.execute(f"select %{fmt_in.value} = array[set_byte('x', 0, %s)]", ([obj], i)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary]) +def test_quote_1byte(conn, scs, pytype): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") + + cur = conn.cursor() + query = sql.SQL("select {ch} = set_byte('x', 0, %s)") + for i in range(0, 256): + obj = pytype(bytes([i])) + cur.execute(query.format(ch=sql.Literal(obj)), (i,)) + assert cur.fetchone()[0] is True, i + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_1byte(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + for i in range(0, 256): + cur.execute("select set_byte('x', 0, %s)", (i,)) + val = cur.fetchone()[0] + assert val == bytes([i]) + + assert isinstance(val, bytes) + assert cur.pgresult.fformat(0) == fmt_out + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_bytea_array(conn, fmt_in, fmt_out): + cur = conn.cursor(binary=fmt_out) + a = [bytes(range(0, 256))] + (res,) = cur.execute(f"select %{fmt_in.value}::bytea[]", (a,)).fetchone() + assert res == a diff --git a/tests/types/test_uuid.py b/tests/types/test_uuid.py new file mode 100644 index 0000000..f86f066 --- /dev/null +++ b/tests/types/test_uuid.py @@ -0,0 +1,56 @@ +import sys +from uuid import UUID +import subprocess as sp + +import pytest + +from psycopg import pq +from psycopg import sql +from psycopg.adapt import PyFormat + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_uuid_dump(conn, fmt_in): + val = "12345678123456781234567812345679" + cur = conn.cursor() + cur.execute(f"select %{fmt_in.value} = %s::uuid", (UUID(val), val)) + assert cur.fetchone()[0] is True + + +@pytest.mark.crdb_skip("copy") +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_uuid_load(conn, fmt_out): + cur = conn.cursor(binary=fmt_out) + val = "12345678123456781234567812345679" + cur.execute("select %s::uuid", (val,)) + assert cur.fetchone()[0] == UUID(val) + + stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format( + val, sql.SQL(fmt_out.name) + ) + with cur.copy(stmt) as copy: + copy.set_types(["uuid"]) + (res,) = copy.read_row() + + assert res == UUID(val) + + +@pytest.mark.slow +@pytest.mark.subprocess +def test_lazy_load(dsn): + script = f"""\ +import sys +import psycopg + +assert 'uuid' not in sys.modules + +conn = psycopg.connect({dsn!r}) +with conn.cursor() as cur: + cur.execute("select repeat('1', 32)::uuid") + cur.fetchone() + +conn.close() +assert 'uuid' in sys.modules +""" + + sp.check_call([sys.executable, "-c", script]) diff --git a/tests/typing_example.py b/tests/typing_example.py new file mode 100644 index 0000000..a26ca49 --- /dev/null +++ b/tests/typing_example.py @@ -0,0 +1,176 @@ +# flake8: builtins=reveal_type + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +from psycopg import Connection, Cursor, ServerCursor, connect, rows +from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor + + +def int_row_factory( + cursor: Union[Cursor[Any], AsyncCursor[Any]] +) -> Callable[[Sequence[int]], int]: + return lambda values: values[0] if values else 42 + + +@dataclass +class Person: + name: str + address: str + + @classmethod + def row_factory( + cls, cursor: Union[Cursor[Any], AsyncCursor[Any]] + ) -> Callable[[Sequence[str]], Person]: + def mkrow(values: Sequence[str]) -> Person: + name, address = values + return cls(name, address) + + return mkrow + + +def kwargsf(*, foo: int, bar: int, baz: int) -> int: + return 42 + + +def argsf(foo: int, bar: int, baz: int) -> float: + return 42.0 + + +def check_row_factory_cursor() -> None: + """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case.""" + conn = connect() + + cur1: Cursor[Any] + cur1 = conn.cursor() + r1: Optional[Any] + r1 = cur1.fetchone() + r1 is not None + + cur2: Cursor[int] + r2: Optional[int] + with conn.cursor(row_factory=int_row_factory) as cur2: + cur2.execute("select 1") + r2 = cur2.fetchone() + r2 and r2 > 0 + + cur3: ServerCursor[Person] + persons: Sequence[Person] + with conn.cursor(name="s", row_factory=Person.row_factory) as cur3: + cur3.execute("select * from persons where name like 'al%'") + persons = cur3.fetchall() + persons[0].address + + +async def async_check_row_factory_cursor() -> None: + """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case.""" + conn = await AsyncConnection.connect() + + cur1: AsyncCursor[Any] + cur1 = conn.cursor() + r1: Optional[Any] + r1 = await cur1.fetchone() + r1 is not None + + cur2: AsyncCursor[int] + r2: Optional[int] + async with conn.cursor(row_factory=int_row_factory) as cur2: + await cur2.execute("select 1") + r2 = await cur2.fetchone() + r2 and r2 > 0 + + cur3: AsyncServerCursor[Person] + persons: Sequence[Person] + async with conn.cursor(name="s", row_factory=Person.row_factory) as cur3: + await cur3.execute("select * from persons where name like 'al%'") + persons = await cur3.fetchall() + persons[0].address + + +def check_row_factory_connection() -> None: + """Type-check connect(..., row_factory=<MyRowFactory>) or + Connection.row_factory cases. + """ + conn1: Connection[int] + cur1: Cursor[int] + r1: Optional[int] + conn1 = connect(row_factory=int_row_factory) + cur1 = conn1.execute("select 1") + r1 = cur1.fetchone() + r1 != 0 + with conn1.cursor() as cur1: + cur1.execute("select 2") + + conn2: Connection[Person] + cur2: Cursor[Person] + r2: Optional[Person] + conn2 = connect(row_factory=Person.row_factory) + cur2 = conn2.execute("select * from persons") + r2 = cur2.fetchone() + r2 and r2.name + with conn2.cursor() as cur2: + cur2.execute("select 2") + + cur3: Cursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = connect() + cur3 = conn3.execute("select 3") + with conn3.cursor() as cur3: + cur3.execute("select 42") + r3 = cur3.fetchone() + r3 and len(r3) + + +async def async_check_row_factory_connection() -> None: + """Type-check connect(..., row_factory=<MyRowFactory>) or + Connection.row_factory cases. + """ + conn1: AsyncConnection[int] + cur1: AsyncCursor[int] + r1: Optional[int] + conn1 = await AsyncConnection.connect(row_factory=int_row_factory) + cur1 = await conn1.execute("select 1") + r1 = await cur1.fetchone() + r1 != 0 + async with conn1.cursor() as cur1: + await cur1.execute("select 2") + + conn2: AsyncConnection[Person] + cur2: AsyncCursor[Person] + r2: Optional[Person] + conn2 = await AsyncConnection.connect(row_factory=Person.row_factory) + cur2 = await conn2.execute("select * from persons") + r2 = await cur2.fetchone() + r2 and r2.name + async with conn2.cursor() as cur2: + await cur2.execute("select 2") + + cur3: AsyncCursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = await AsyncConnection.connect() + cur3 = await conn3.execute("select 3") + async with conn3.cursor() as cur3: + await cur3.execute("select 42") + r3 = await cur3.fetchone() + r3 and len(r3) + + +def check_row_factories() -> None: + conn1 = connect(row_factory=rows.tuple_row) + v1: Tuple[Any, ...] = conn1.execute("").fetchall()[0] + + conn2 = connect(row_factory=rows.dict_row) + v2: Dict[str, Any] = conn2.execute("").fetchall()[0] + + conn3 = connect(row_factory=rows.class_row(Person)) + v3: Person = conn3.execute("").fetchall()[0] + + conn4 = connect(row_factory=rows.args_row(argsf)) + v4: float = conn4.execute("").fetchall()[0] + + conn5 = connect(row_factory=rows.kwargs_row(kwargsf)) + v5: int = conn5.execute("").fetchall()[0] + + v1, v2, v3, v4, v5 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..871f65d --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,179 @@ +import gc +import re +import sys +import operator +from typing import Callable, Optional, Tuple + +import pytest + +eur = "\u20ac" + + +def check_libpq_version(got, want): + """ + Verify if the libpq version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.libpq(">= 12") + + and skips the test if the requested version doesn't match what's loaded. + """ + return check_version(got, want, "libpq", postgres_rule=True) + + +def check_postgres_version(got, want): + """ + Verify if the server version is a version accepted. + + This function is called on the tests marked with something like:: + + @pytest.mark.pg(">= 12") + + and skips the test if the server version doesn't match what expected. + """ + return check_version(got, want, "PostgreSQL", postgres_rule=True) + + +def check_version(got, want, whose_version, postgres_rule=True): + pred = VersionCheck.parse(want, postgres_rule=postgres_rule) + pred.whose = whose_version + return pred.get_skip_message(got) + + +class VersionCheck: + """ + Helper to compare a version number with a test spec. + """ + + def __init__( + self, + *, + skip: bool = False, + op: Optional[str] = None, + version_tuple: Tuple[int, ...] = (), + whose: str = "(wanted)", + postgres_rule: bool = False, + ): + self.skip = skip + self.op = op or "==" + self.version_tuple = version_tuple + self.whose = whose + # Treat 10.1 as 10.0.1 + self.postgres_rule = postgres_rule + + @classmethod + def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck": + # Parse a spec like "> 9.6", "skip < 21.2.0" + m = re.match( + r"""(?ix) + ^\s* (skip|only)? + \s* (==|!=|>=|<=|>|<)? + \s* (?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)? + \s* $ + """, + spec, + ) + if m is None: + pytest.fail(f"bad wanted version spec: {spec}") + + skip = (m.group(1) or "only").lower() == "skip" + op = m.group(2) + version_tuple = tuple(int(n) for n in m.groups()[2:] if n) + + return cls( + skip=skip, op=op, version_tuple=version_tuple, postgres_rule=postgres_rule + ) + + def get_skip_message(self, version: Optional[int]) -> Optional[str]: + got_tuple = self._parse_int_version(version) + + msg: Optional[str] = None + if self.skip: + if got_tuple: + if not self.version_tuple: + msg = f"skip on {self.whose}" + elif self._match_version(got_tuple): + msg = ( + f"skip on {self.whose} {self.op}" + f" {'.'.join(map(str, self.version_tuple))}" + ) + else: + if not got_tuple: + msg = f"only for {self.whose}" + elif not self._match_version(got_tuple): + if self.version_tuple: + msg = ( + f"only for {self.whose} {self.op}" + f" {'.'.join(map(str, self.version_tuple))}" + ) + else: + msg = f"only for {self.whose}" + + return msg + + _OP_NAMES = {">=": "ge", "<=": "le", ">": "gt", "<": "lt", "==": "eq", "!=": "ne"} + + def _match_version(self, got_tuple: Tuple[int, ...]) -> bool: + if not self.version_tuple: + return True + + version_tuple = self.version_tuple + if self.postgres_rule and version_tuple and version_tuple[0] >= 10: + assert len(version_tuple) <= 2 + version_tuple = version_tuple[:1] + (0,) + version_tuple[1:] + + op: Callable[[Tuple[int, ...], Tuple[int, ...]], bool] + op = getattr(operator, self._OP_NAMES[self.op]) + return op(got_tuple, version_tuple) + + def _parse_int_version(self, version: Optional[int]) -> Tuple[int, ...]: + if version is None: + return () + version, ver_fix = divmod(version, 100) + ver_maj, ver_min = divmod(version, 100) + return (ver_maj, ver_min, ver_fix) + + +def gc_collect(): + """ + gc.collect(), but more insisting. + """ + for i in range(3): + gc.collect() + + +NO_COUNT_TYPES: Tuple[type, ...] = () + +if sys.version_info[:2] == (3, 10): + # On my laptop there are occasional creations of a single one of these objects + # with empty content, which might be some Decimal caching. + # Keeping the guard as strict as possible, to be extended if other types + # or versions are necessary. + try: + from _contextvars import Context # type: ignore + except ImportError: + pass + else: + NO_COUNT_TYPES += (Context,) + + +def gc_count() -> int: + """ + len(gc.get_objects()), with subtleties. + """ + if not NO_COUNT_TYPES: + return len(gc.get_objects()) + + # Note: not using a list comprehension because it pollutes the objects list. + rv = 0 + for obj in gc.get_objects(): + if isinstance(obj, NO_COUNT_TYPES): + continue + rv += 1 + + return rv + + +async def alist(it): + return [i async for i in it] |