diff options
Diffstat (limited to 'tests/crdb')
-rw-r--r-- | tests/crdb/__init__.py | 0 | ||||
-rw-r--r-- | tests/crdb/test_adapt.py | 78 | ||||
-rw-r--r-- | tests/crdb/test_connection.py | 86 | ||||
-rw-r--r-- | tests/crdb/test_connection_async.py | 85 | ||||
-rw-r--r-- | tests/crdb/test_conninfo.py | 21 | ||||
-rw-r--r-- | tests/crdb/test_copy.py | 233 | ||||
-rw-r--r-- | tests/crdb/test_copy_async.py | 235 | ||||
-rw-r--r-- | tests/crdb/test_cursor.py | 65 | ||||
-rw-r--r-- | tests/crdb/test_cursor_async.py | 61 | ||||
-rw-r--r-- | tests/crdb/test_no_crdb.py | 34 | ||||
-rw-r--r-- | tests/crdb/test_typing.py | 49 |
11 files changed, 947 insertions, 0 deletions
diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/crdb/__init__.py diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py new file mode 100644 index 0000000..ce5bacf --- /dev/null +++ b/tests/crdb/test_adapt.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +import pytest + +from psycopg.crdb import adapters, CrdbConnection + +from psycopg.adapt import PyFormat, Transformer +from psycopg.types.array import ListDumper +from psycopg.postgres import types as builtins + +from ..test_adapt import MyStr, make_dumper, make_bin_dumper +from ..test_adapt import make_loader, make_bin_loader + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_return_untyped(conn, fmt_in): + # Analyze and check for changes using strings in untyped/typed contexts + cur = conn.cursor() + # Currently string are passed as text oid to CockroachDB, unlike Postgres, + # to which strings are passed as unknown. This is because CRDB doesn't + # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises + # an error. However, unlike PostgreSQL, text can be cast to any other type. + cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10]) + assert cur.fetchone() == ("hello", 10) + + cur.execute("create table testjson(data jsonb)") + cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"]) + assert cur.execute("select data from testjson").fetchone() == ({},) + + +def test_str_list_dumper_text(conn): + t = Transformer(conn) + dstr = t.get_dumper([""], PyFormat.TEXT) + assert isinstance(dstr, ListDumper) + assert dstr.oid == builtins["text"].array_oid + assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid + + +@pytest.fixture +def crdb_adapters(): + """Restore the crdb adapters after a test has changed them.""" + dumpers = deepcopy(adapters._dumpers) + dumpers_by_oid = deepcopy(adapters._dumpers_by_oid) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._dumpers_by_oid = dumpers_by_oid + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) + + +def test_dump_global_ctx(dsn, crdb_adapters, pgconn): + adapters.register_dumper(MyStr, make_bin_dumper("gb")) + adapters.register_dumper(MyStr, make_dumper("gt")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + + +def test_load_global_ctx(dsn, crdb_adapters): + adapters.register_loader("text", make_loader("gt")) + adapters.register_loader("text", make_bin_loader("gb")) + with CrdbConnection.connect(dsn) as conn: + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py new file mode 100644 index 0000000..b2a69ef --- /dev/null +++ b/tests/crdb/test_connection.py @@ -0,0 +1,86 @@ +import time +import threading + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb + + +def test_is_crdb(conn): + assert CrdbConnection.is_crdb(conn) + assert CrdbConnection.is_crdb(conn.pgconn) + + +def test_connect(dsn): + with CrdbConnection.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + with psycopg.crdb.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + + +def test_xid(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +def test_tpc_begin(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_begin("foo") + + +def test_tpc_recover(dsn): + with CrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.tpc_recover() + + +@pytest.mark.slow +def test_broken_connection(conn): + cur = conn.cursor() + (session_id,) = cur.execute("select session_id from [show session_id]").fetchone() + with pytest.raises(psycopg.DatabaseError): + cur.execute("cancel session %s", [session_id]) + assert conn.closed + + +@pytest.mark.slow +def test_broken(conn): + (session_id,) = conn.execute("show session_id").fetchone() + with pytest.raises(psycopg.OperationalError): + conn.execute("cancel session %s", [session_id]) + + assert conn.closed + assert conn.broken + conn.close() + assert conn.closed + assert conn.broken + + +@pytest.mark.slow +def test_identify_closure(conn_cls, dsn): + with conn_cls.connect(dsn, autocommit=True) as conn: + with conn_cls.connect(dsn, autocommit=True) as conn2: + (session_id,) = conn.execute("show session_id").fetchone() + + def closer(): + time.sleep(0.2) + conn2.execute("cancel session %s", [session_id]) + + t = threading.Thread(target=closer) + t.start() + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + # CRDB seems to take not less than 1s + assert 0.2 < dt < 2 + finally: + t.join() diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py new file mode 100644 index 0000000..b568e42 --- /dev/null +++ b/tests/crdb/test_connection_async.py @@ -0,0 +1,85 @@ +import time +import asyncio + +import psycopg.crdb +from psycopg import errors as e +from psycopg.crdb import AsyncCrdbConnection +from psycopg._compat import create_task + +import pytest + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +async def test_is_crdb(aconn): + assert AsyncCrdbConnection.is_crdb(aconn) + assert AsyncCrdbConnection.is_crdb(aconn.pgconn) + + +async def test_connect(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) + + +async def test_xid(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + conn.xid(1, "gtrid", "bqual") + + +async def test_tpc_begin(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_begin("foo") + + +async def test_tpc_recover(dsn): + async with await AsyncCrdbConnection.connect(dsn) as conn: + with pytest.raises(e.NotSupportedError): + await conn.tpc_recover() + + +@pytest.mark.slow +async def test_broken_connection(aconn): + cur = aconn.cursor() + await cur.execute("select session_id from [show session_id]") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.DatabaseError): + await cur.execute("cancel session %s", [session_id]) + assert aconn.closed + + +@pytest.mark.slow +async def test_broken(aconn): + cur = await aconn.execute("show session_id") + (session_id,) = await cur.fetchone() + with pytest.raises(psycopg.OperationalError): + await aconn.execute("cancel session %s", [session_id]) + + assert aconn.closed + assert aconn.broken + await aconn.close() + assert aconn.closed + assert aconn.broken + + +@pytest.mark.slow +async def test_identify_closure(aconn_cls, dsn): + async with await aconn_cls.connect(dsn) as conn: + async with await aconn_cls.connect(dsn) as conn2: + cur = await conn.execute("show session_id") + (session_id,) = await cur.fetchone() + + async def closer(): + await asyncio.sleep(0.2) + await conn2.execute("cancel session %s", [session_id]) + + t = create_task(closer()) + t0 = time.time() + try: + with pytest.raises(psycopg.OperationalError): + await conn.execute("select pg_sleep(3.0)") + dt = time.time() - t0 + assert 0.2 < dt < 2 + finally: + await asyncio.gather(t) diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py new file mode 100644 index 0000000..274a0c0 --- /dev/null +++ b/tests/crdb/test_conninfo.py @@ -0,0 +1,21 @@ +import pytest + +pytestmark = pytest.mark.crdb + + +def test_vendor(conn): + assert conn.info.vendor == "CockroachDB" + + +def test_server_version(conn): + assert conn.info.server_version > 200000 + + +@pytest.mark.crdb("< 22") +def test_backend_pid_pre_22(conn): + assert conn.info.backend_pid == 0 + + +@pytest.mark.crdb(">= 22") +def test_backend_pid(conn): + assert conn.info.backend_pid > 0 diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py new file mode 100644 index 0000000..b7d26aa --- /dev/null +++ b/tests/crdb/test_copy.py @@ -0,0 +1,233 @@ +import pytest +import string +from random import randrange, choice + +from psycopg import sql, errors as e +from psycopg.pq import Format +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import ensure_table, sample_records +from ..test_copy import sample_tabledef as sample_tabledef_pg + +# CRDB int/serial are int8 +sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") + +pytestmark = pytest.mark.crdb + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +def test_copy_in_buffers(conn, format, buffer): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.write(globals()[buffer]) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +def test_copy_in_buffers_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_str(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text.decode()) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +def test_copy_in_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + with cur.copy("copy copy_in from stdin with binary") as copy: + copy.write(sample_text.decode()) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_empty(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert conn.info.transaction_status == conn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +def test_copy_big_size_record(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write_row([data]) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +@pytest.mark.slow +def test_copy_big_size_block(conn): + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + with cur.copy("copy copy_in (data) from stdin") as copy: + copy.write(copy_data) + + cur.execute("select data from copy_in limit 1") + assert cur.fetchone()[0] == data + + +def test_copy_in_buffers_with_pg_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + copy.write(sample_text) + + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_set_types(conn, format): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + + with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + copy.write_row(row) + + data = cur.execute("select * from copy_in order by 1").fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +def test_copy_in_records_binary(conn, format): + cur = conn.cursor() + ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy: + for row in sample_records: + copy.write_row((None, row[2])) + + data = cur.execute("select col2, data from copy_in order by 2").fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +def test_copy_in_buffers_with_py_error(conn): + cur = conn.cursor() + ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + with cur.copy("copy copy_in from stdin") as copy: + copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert conn.info.transaction_status == conn.TransactionStatus.INERROR + + +def test_copy_in_allchars(conn): + cur = conn.cursor() + ensure_table(cur, "col1 int primary key, col2 int, data text") + + with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + copy.write_row((i, None, chr(i))) + copy.write_row((ord(eur), None, eur)) + + data = cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ).fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + def work(): + with conn_cls.connect(dsn) as conn: + with conn.cursor(binary=fmt) as cur: + cur.execute(faker.drop_stmt) + cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + copy.write_row(row) + + cur.execute(faker.select_stmt) + recs = cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def copyopt(format): + return "with binary" if format == Format.BINARY else "" diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py new file mode 100644 index 0000000..d5fbf50 --- /dev/null +++ b/tests/crdb/test_copy_async.py @@ -0,0 +1,235 @@ +import pytest +import string +from random import randrange, choice + +from psycopg.pq import Format +from psycopg import sql, errors as e +from psycopg.adapt import PyFormat +from psycopg.types.numeric import Int4 + +from ..utils import eur, gc_collect, gc_count +from ..test_copy import sample_text, sample_binary # noqa +from ..test_copy import sample_records +from ..test_copy_async import ensure_table +from .test_copy import sample_tabledef, copyopt + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.parametrize( + "format, buffer", + [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559") +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin with binary") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +@pytest.mark.slow +async def test_copy_big_size_block(aconn): + cur = aconn.cursor() + await ensure_table(cur, "id serial primary key, data text") + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert (await cur.fetchone())[0] == data + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int4, data text") + + async with cur.copy( + f"copy copy_in (col2, data) from stdin {copyopt(format)}" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select col2, data from copy_in order by 2") + data = await cur.fetchall() + assert data == [(None, "hello"), (None, "world")] + + +@pytest.mark.crdb_skip("copy canceled") +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, "col1 int primary key, col2 int, data text") + + async with cur.copy("copy copy_in from stdin") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.crdb_skip("copy array") +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin {}").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL("with binary" if fmt else ""), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py new file mode 100644 index 0000000..991b084 --- /dev/null +++ b/tests/crdb/test_cursor.py @@ -0,0 +1,65 @@ +import json +import threading +from uuid import uuid4 +from queue import Queue +from typing import Any + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.crdb + + +@pytest.fixture +def testfeed(svcconn): + name = f"test_feed_{str(uuid4()).replace('-', '')}" + svcconn.execute("set cluster setting kv.rangefeed.enabled to true") + svcconn.execute(f"create table {name} (id serial primary key, data text)") + yield name + svcconn.execute(f"drop table {name}") + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): + conn.autocommit = True + q: "Queue[Any]" = Queue() + + def worker(): + try: + with conn_cls.connect(dsn, autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + for row in cur.stream(f"experimental changefeed for {testfeed}"): + q.put(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put(None) + except Exception as ex: + q.put(ex) + + t = threading.Thread(target=worker) + t.start() + + cur = conn.cursor() + cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = cur.fetchone() + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = q.get(timeout=1) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = cur.fetchone() + cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert q.get(timeout=1) is None + t.join() diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py new file mode 100644 index 0000000..229295d --- /dev/null +++ b/tests/crdb/test_cursor_async.py @@ -0,0 +1,61 @@ +import json +import asyncio +from typing import Any +from asyncio.queues import Queue + +import pytest +from psycopg import pq, errors as e +from psycopg.rows import namedtuple_row +from psycopg._compat import create_task + +from .test_cursor import testfeed + +testfeed # fixture + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +@pytest.mark.slow +@pytest.mark.parametrize("fmt_out", pq.Format) +async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): + await aconn.set_autocommit(True) + q: "Queue[Any]" = Queue() + + async def worker(): + try: + async with await aconn_cls.connect(dsn, autocommit=True) as conn: + cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) + try: + async for row in cur.stream( + f"experimental changefeed for {testfeed}" + ): + q.put_nowait(row) + except e.QueryCanceled: + assert conn.info.transaction_status == conn.TransactionStatus.IDLE + q.put_nowait(None) + except Exception as ex: + q.put_nowait(ex) + + t = create_task(worker()) + + cur = aconn.cursor() + await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") + (key,) = await cur.fetchone() + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} + + await cur.execute(f"delete from {testfeed} where id = %s", [key]) + row = await asyncio.wait_for(q.get(), 1.0) + assert row.table == testfeed + assert json.loads(row.key) == [key] + assert json.loads(row.value)["after"] is None + + await cur.execute("select query_id from [show statements] where query !~ 'show'") + (qid,) = await cur.fetchone() + await cur.execute("cancel query %s", [qid]) + assert cur.statusmessage == "CANCEL QUERIES 1" + + assert await asyncio.wait_for(q.get(), 1.0) is None + await asyncio.gather(t) diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py new file mode 100644 index 0000000..df43f3b --- /dev/null +++ b/tests/crdb/test_no_crdb.py @@ -0,0 +1,34 @@ +from psycopg.pq import TransactionStatus +from psycopg.crdb import CrdbConnection + +import pytest + +pytestmark = pytest.mark.crdb("skip") + + +def test_is_crdb(conn): + assert not CrdbConnection.is_crdb(conn) + assert not CrdbConnection.is_crdb(conn.pgconn) + + +def test_tpc_on_pg_connection(conn, tpc): + xid = conn.xid(1, "gtrid", "bqual") + assert conn.info.transaction_status == TransactionStatus.IDLE + + conn.tpc_begin(xid) + assert conn.info.transaction_status == TransactionStatus.INTRANS + + cur = conn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit')") + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 0 + + conn.tpc_prepare() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 1 + assert tpc.count_test_records() == 0 + + conn.tpc_commit() + assert conn.info.transaction_status == TransactionStatus.IDLE + assert tpc.count_xacts() == 0 + assert tpc.count_test_records() == 1 diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py new file mode 100644 index 0000000..2cff0a7 --- /dev/null +++ b/tests/crdb/test_typing.py @@ -0,0 +1,49 @@ +import pytest + +from ..test_typing import _test_reveal + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.crdb.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect()", + "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal_crdb(stmts, type, mypy) + + +def _test_reveal_crdb(stmts, type, mypy): + stmts = f"""\ +import psycopg.crdb +{stmts} +""" + _test_reveal(stmts, type, mypy) |