summaryrefslogtreecommitdiffstats
path: root/tests/crdb
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crdb')
-rw-r--r--tests/crdb/__init__.py0
-rw-r--r--tests/crdb/test_adapt.py78
-rw-r--r--tests/crdb/test_connection.py86
-rw-r--r--tests/crdb/test_connection_async.py85
-rw-r--r--tests/crdb/test_conninfo.py21
-rw-r--r--tests/crdb/test_copy.py233
-rw-r--r--tests/crdb/test_copy_async.py235
-rw-r--r--tests/crdb/test_cursor.py65
-rw-r--r--tests/crdb/test_cursor_async.py61
-rw-r--r--tests/crdb/test_no_crdb.py34
-rw-r--r--tests/crdb/test_typing.py49
11 files changed, 947 insertions, 0 deletions
diff --git a/tests/crdb/__init__.py b/tests/crdb/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/crdb/__init__.py
diff --git a/tests/crdb/test_adapt.py b/tests/crdb/test_adapt.py
new file mode 100644
index 0000000..ce5bacf
--- /dev/null
+++ b/tests/crdb/test_adapt.py
@@ -0,0 +1,78 @@
+from copy import deepcopy
+
+import pytest
+
+from psycopg.crdb import adapters, CrdbConnection
+
+from psycopg.adapt import PyFormat, Transformer
+from psycopg.types.array import ListDumper
+from psycopg.postgres import types as builtins
+
+from ..test_adapt import MyStr, make_dumper, make_bin_dumper
+from ..test_adapt import make_loader, make_bin_loader
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as text oid to CockroachDB, unlike Postgres,
+ # to which strings are passed as unknown. This is because CRDB doesn't
+ # allow the unknown oid to be emitted; execute("SELECT %s", ["str"]) raises
+ # an error. However, unlike PostgreSQL, text can be cast to any other type.
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+
+
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+@pytest.fixture
+def crdb_adapters():
+ """Restore the crdb adapters after a test has changed them."""
+ dumpers = deepcopy(adapters._dumpers)
+ dumpers_by_oid = deepcopy(adapters._dumpers_by_oid)
+ loaders = deepcopy(adapters._loaders)
+ types = list(adapters.types)
+
+ yield None
+
+ adapters._dumpers = dumpers
+ adapters._dumpers_by_oid = dumpers_by_oid
+ adapters._loaders = loaders
+ adapters.types.clear()
+ for t in types:
+ adapters.types.add(t)
+
+
+def test_dump_global_ctx(dsn, crdb_adapters, pgconn):
+ adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ adapters.register_dumper(MyStr, make_dumper("gt"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_load_global_ctx(dsn, crdb_adapters):
+ adapters.register_loader("text", make_loader("gt"))
+ adapters.register_loader("text", make_bin_loader("gb"))
+ with CrdbConnection.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py
new file mode 100644
index 0000000..b2a69ef
--- /dev/null
+++ b/tests/crdb/test_connection.py
@@ -0,0 +1,86 @@
+import time
+import threading
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_is_crdb(conn):
+ assert CrdbConnection.is_crdb(conn)
+ assert CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_connect(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+ with psycopg.crdb.connect(dsn) as conn:
+ assert isinstance(conn, CrdbConnection)
+
+
+def test_xid(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+def test_tpc_begin(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_begin("foo")
+
+
+def test_tpc_recover(dsn):
+ with CrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.tpc_recover()
+
+
+@pytest.mark.slow
+def test_broken_connection(conn):
+ cur = conn.cursor()
+ (session_id,) = cur.execute("select session_id from [show session_id]").fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ cur.execute("cancel session %s", [session_id])
+ assert conn.closed
+
+
+@pytest.mark.slow
+def test_broken(conn):
+ (session_id,) = conn.execute("show session_id").fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("cancel session %s", [session_id])
+
+ assert conn.closed
+ assert conn.broken
+ conn.close()
+ assert conn.closed
+ assert conn.broken
+
+
+@pytest.mark.slow
+def test_identify_closure(conn_cls, dsn):
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ with conn_cls.connect(dsn, autocommit=True) as conn2:
+ (session_id,) = conn.execute("show session_id").fetchone()
+
+ def closer():
+ time.sleep(0.2)
+ conn2.execute("cancel session %s", [session_id])
+
+ t = threading.Thread(target=closer)
+ t.start()
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ # CRDB seems to take not less than 1s
+ assert 0.2 < dt < 2
+ finally:
+ t.join()
diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py
new file mode 100644
index 0000000..b568e42
--- /dev/null
+++ b/tests/crdb/test_connection_async.py
@@ -0,0 +1,85 @@
+import time
+import asyncio
+
+import psycopg.crdb
+from psycopg import errors as e
+from psycopg.crdb import AsyncCrdbConnection
+from psycopg._compat import create_task
+
+import pytest
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+async def test_is_crdb(aconn):
+ assert AsyncCrdbConnection.is_crdb(aconn)
+ assert AsyncCrdbConnection.is_crdb(aconn.pgconn)
+
+
+async def test_connect(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
+
+
+async def test_xid(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ conn.xid(1, "gtrid", "bqual")
+
+
+async def test_tpc_begin(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_begin("foo")
+
+
+async def test_tpc_recover(dsn):
+ async with await AsyncCrdbConnection.connect(dsn) as conn:
+ with pytest.raises(e.NotSupportedError):
+ await conn.tpc_recover()
+
+
+@pytest.mark.slow
+async def test_broken_connection(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select session_id from [show session_id]")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.DatabaseError):
+ await cur.execute("cancel session %s", [session_id])
+ assert aconn.closed
+
+
+@pytest.mark.slow
+async def test_broken(aconn):
+ cur = await aconn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+ with pytest.raises(psycopg.OperationalError):
+ await aconn.execute("cancel session %s", [session_id])
+
+ assert aconn.closed
+ assert aconn.broken
+ await aconn.close()
+ assert aconn.closed
+ assert aconn.broken
+
+
+@pytest.mark.slow
+async def test_identify_closure(aconn_cls, dsn):
+ async with await aconn_cls.connect(dsn) as conn:
+ async with await aconn_cls.connect(dsn) as conn2:
+ cur = await conn.execute("show session_id")
+ (session_id,) = await cur.fetchone()
+
+ async def closer():
+ await asyncio.sleep(0.2)
+ await conn2.execute("cancel session %s", [session_id])
+
+ t = create_task(closer())
+ t0 = time.time()
+ try:
+ with pytest.raises(psycopg.OperationalError):
+ await conn.execute("select pg_sleep(3.0)")
+ dt = time.time() - t0
+ assert 0.2 < dt < 2
+ finally:
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py
new file mode 100644
index 0000000..274a0c0
--- /dev/null
+++ b/tests/crdb/test_conninfo.py
@@ -0,0 +1,21 @@
+import pytest
+
+pytestmark = pytest.mark.crdb
+
+
+def test_vendor(conn):
+ assert conn.info.vendor == "CockroachDB"
+
+
+def test_server_version(conn):
+ assert conn.info.server_version > 200000
+
+
+@pytest.mark.crdb("< 22")
+def test_backend_pid_pre_22(conn):
+ assert conn.info.backend_pid == 0
+
+
+@pytest.mark.crdb(">= 22")
+def test_backend_pid(conn):
+ assert conn.info.backend_pid > 0
diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py
new file mode 100644
index 0000000..b7d26aa
--- /dev/null
+++ b/tests/crdb/test_copy.py
@@ -0,0 +1,233 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg import sql, errors as e
+from psycopg.pq import Format
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import ensure_table, sample_records
+from ..test_copy import sample_tabledef as sample_tabledef_pg
+
+# CRDB int/serial are int8
+sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin with binary") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+def test_copy_big_size_block(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ with cur.copy(f"copy copy_in (col2, data) from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select col2, data from copy_in order by 2").fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def copyopt(format):
+ return "with binary" if format == Format.BINARY else ""
diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py
new file mode 100644
index 0000000..d5fbf50
--- /dev/null
+++ b/tests/crdb/test_copy_async.py
@@ -0,0 +1,235 @@
+import pytest
+import string
+from random import randrange, choice
+
+from psycopg.pq import Format
+from psycopg import sql, errors as e
+from psycopg.adapt import PyFormat
+from psycopg.types.numeric import Int4
+
+from ..utils import eur, gc_collect, gc_count
+from ..test_copy import sample_text, sample_binary # noqa
+from ..test_copy import sample_records
+from ..test_copy_async import ensure_table
+from .test_copy import sample_tabledef, copyopt
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+async def test_copy_in_buffers(aconn, format, buffer):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ await copy.write(globals()[buffer])
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+async def test_copy_in_buffers_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_str(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text.decode())
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.xfail(reason="bad sqlstate - CRDB #81559")
+async def test_copy_in_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ async with cur.copy("copy copy_in from stdin with binary") as copy:
+ await copy.write(sample_text.decode())
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_empty(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}"):
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+async def test_copy_big_size_record(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write_row([data])
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+@pytest.mark.slow
+async def test_copy_big_size_block(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "id serial primary key, data text")
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n"
+ async with cur.copy("copy copy_in (data) from stdin") as copy:
+ await copy.write(copy_data)
+
+ await cur.execute("select data from copy_in limit 1")
+ assert (await cur.fetchone())[0] == data
+
+
+async def test_copy_in_buffers_with_pg_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ await copy.write(sample_text)
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_set_types(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+
+ async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ await copy.write_row(row)
+
+ await cur.execute("select * from copy_in order by 1")
+ data = await cur.fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+async def test_copy_in_records_binary(aconn, format):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 serial primary key, col2 int4, data text")
+
+ async with cur.copy(
+ f"copy copy_in (col2, data) from stdin {copyopt(format)}"
+ ) as copy:
+ for row in sample_records:
+ await copy.write_row((None, row[2]))
+
+ await cur.execute("select col2, data from copy_in order by 2")
+ data = await cur.fetchall()
+ assert data == [(None, "hello"), (None, "world")]
+
+
+@pytest.mark.crdb_skip("copy canceled")
+async def test_copy_in_buffers_with_py_error(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ async with cur.copy("copy copy_in from stdin") as copy:
+ await copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_in_allchars(aconn):
+ cur = aconn.cursor()
+ await ensure_table(cur, "col1 int primary key, col2 int, data text")
+
+ async with cur.copy("copy copy_in from stdin") as copy:
+ for i in range(1, 256):
+ await copy.write_row((i, None, chr(i)))
+ await copy.write_row((ord(eur), None, eur))
+
+ await cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ )
+ data = await cur.fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.crdb_skip("copy array")
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ async def work():
+ async with await aconn_cls.connect(dsn) as conn:
+ async with conn.cursor(binary=fmt) as cur:
+ await cur.execute(faker.drop_stmt)
+ await cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin {}").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL("with binary" if fmt else ""),
+ )
+ async with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ await copy.write_row(row)
+
+ await cur.execute(faker.select_stmt)
+ recs = await cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ await work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py
new file mode 100644
index 0000000..991b084
--- /dev/null
+++ b/tests/crdb/test_cursor.py
@@ -0,0 +1,65 @@
+import json
+import threading
+from uuid import uuid4
+from queue import Queue
+from typing import Any
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.crdb
+
+
+@pytest.fixture
+def testfeed(svcconn):
+ name = f"test_feed_{str(uuid4()).replace('-', '')}"
+ svcconn.execute("set cluster setting kv.rangefeed.enabled to true")
+ svcconn.execute(f"create table {name} (id serial primary key, data text)")
+ yield name
+ svcconn.execute(f"drop table {name}")
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
+ conn.autocommit = True
+ q: "Queue[Any]" = Queue()
+
+ def worker():
+ try:
+ with conn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ for row in cur.stream(f"experimental changefeed for {testfeed}"):
+ q.put(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put(None)
+ except Exception as ex:
+ q.put(ex)
+
+ t = threading.Thread(target=worker)
+ t.start()
+
+ cur = conn.cursor()
+ cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = cur.fetchone()
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = q.get(timeout=1)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = cur.fetchone()
+ cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert q.get(timeout=1) is None
+ t.join()
diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py
new file mode 100644
index 0000000..229295d
--- /dev/null
+++ b/tests/crdb/test_cursor_async.py
@@ -0,0 +1,61 @@
+import json
+import asyncio
+from typing import Any
+from asyncio.queues import Queue
+
+import pytest
+from psycopg import pq, errors as e
+from psycopg.rows import namedtuple_row
+from psycopg._compat import create_task
+
+from .test_cursor import testfeed
+
+testfeed # fixture
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt_out", pq.Format)
+async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
+ await aconn.set_autocommit(True)
+ q: "Queue[Any]" = Queue()
+
+ async def worker():
+ try:
+ async with await aconn_cls.connect(dsn, autocommit=True) as conn:
+ cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
+ try:
+ async for row in cur.stream(
+ f"experimental changefeed for {testfeed}"
+ ):
+ q.put_nowait(row)
+ except e.QueryCanceled:
+ assert conn.info.transaction_status == conn.TransactionStatus.IDLE
+ q.put_nowait(None)
+ except Exception as ex:
+ q.put_nowait(ex)
+
+ t = create_task(worker())
+
+ cur = aconn.cursor()
+ await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
+ (key,) = await cur.fetchone()
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
+
+ await cur.execute(f"delete from {testfeed} where id = %s", [key])
+ row = await asyncio.wait_for(q.get(), 1.0)
+ assert row.table == testfeed
+ assert json.loads(row.key) == [key]
+ assert json.loads(row.value)["after"] is None
+
+ await cur.execute("select query_id from [show statements] where query !~ 'show'")
+ (qid,) = await cur.fetchone()
+ await cur.execute("cancel query %s", [qid])
+ assert cur.statusmessage == "CANCEL QUERIES 1"
+
+ assert await asyncio.wait_for(q.get(), 1.0) is None
+ await asyncio.gather(t)
diff --git a/tests/crdb/test_no_crdb.py b/tests/crdb/test_no_crdb.py
new file mode 100644
index 0000000..df43f3b
--- /dev/null
+++ b/tests/crdb/test_no_crdb.py
@@ -0,0 +1,34 @@
+from psycopg.pq import TransactionStatus
+from psycopg.crdb import CrdbConnection
+
+import pytest
+
+pytestmark = pytest.mark.crdb("skip")
+
+
+def test_is_crdb(conn):
+ assert not CrdbConnection.is_crdb(conn)
+ assert not CrdbConnection.is_crdb(conn.pgconn)
+
+
+def test_tpc_on_pg_connection(conn, tpc):
+ xid = conn.xid(1, "gtrid", "bqual")
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+
+ conn.tpc_begin(xid)
+ assert conn.info.transaction_status == TransactionStatus.INTRANS
+
+ cur = conn.cursor()
+ cur.execute("insert into test_tpc values ('test_tpc_commit')")
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_prepare()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 1
+ assert tpc.count_test_records() == 0
+
+ conn.tpc_commit()
+ assert conn.info.transaction_status == TransactionStatus.IDLE
+ assert tpc.count_xacts() == 0
+ assert tpc.count_test_records() == 1
diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py
new file mode 100644
index 0000000..2cff0a7
--- /dev/null
+++ b/tests/crdb/test_typing.py
@@ -0,0 +1,49 @@
+import pytest
+
+from ..test_typing import _test_reveal
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.crdb.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect()",
+ "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal_crdb(stmts, type, mypy)
+
+
+def _test_reveal_crdb(stmts, type, mypy):
+ stmts = f"""\
+import psycopg.crdb
+{stmts}
+"""
+ _test_reveal(stmts, type, mypy)