diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /tests/test_copy_async.py | |
parent | Initial commit. (diff) | |
download | psycopg3-7acd1f75a918595b0fe5366910f80c2a8527072c.tar.xz psycopg3-7acd1f75a918595b0fe5366910f80c2a8527072c.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_copy_async.py')
-rw-r--r-- | tests/test_copy_async.py | 892 |
1 files changed, 892 insertions, 0 deletions
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py new file mode 100644 index 0000000..59389dd --- /dev/null +++ b/tests/test_copy_async.py @@ -0,0 +1,892 @@ +import string +import hashlib +from io import BytesIO, StringIO +from random import choice, randrange +from itertools import cycle + +import pytest + +import psycopg +from psycopg import pq +from psycopg import sql +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.copy import AsyncCopy +from psycopg.copy import AsyncWriter, AsyncLibpqWriter, AsyncQueuedLibpqWriter +from psycopg.types import TypeInfo +from psycopg.adapt import PyFormat +from psycopg.types.hstore import register_hstore +from psycopg.types.numeric import Int4 + +from .utils import alist, eur, gc_collect, gc_count +from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa +from .test_copy import sample_values, sample_records, sample_tabledef +from .test_copy import py_to_raw, special_chars + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("copy"), +] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_read(aconn, format): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + for row in want: + got = await copy.read() + assert got == row + assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE + + assert await copy.read() == b"" + assert await copy.read() == b"" + + assert await copy.read() == b"" + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_iter(aconn, format, row_factory): + if format == pq.Format.TEXT: + want = [row + b"\n" for row in sample_text.splitlines()] + else: + want = sample_binary_rows + + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + assert await alist(copy) == want + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_copy_out_no_result(aconn, format, row_factory): + rf = getattr(psycopg.rows, row_factory) + cur = aconn.cursor(row_factory=rf) + async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"): + with pytest.raises(e.ProgrammingError): + await cur.fetchone() + + +@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})]) +async def test_copy_out_param(aconn, ph, params): + cur = aconn.cursor() + async with cur.copy( + f"copy (select * from generate_series(1, {ph})) to stdout", params + ) as copy: + copy.set_types(["int4"]) + assert await alist(copy.rows()) == [(i + 1,) for i in range(10)] + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +@pytest.mark.parametrize("typetype", ["names", "oids"]) +async def test_read_rows(aconn, format, typetype): + cur = aconn.cursor() + async with cur.copy( + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" + ) as copy: + copy.set_types(["int4", "text", "float8[]"]) + row = await copy.read_row() + assert (await copy.read_row()) is None + + assert row == (10, "hello", [0.0, 1.0]) + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +@pytest.mark.parametrize("format", Format) +async def test_rows(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types("int4 int4 text".split()) + rows = await alist(copy.rows()) + + assert rows == sample_records + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_set_custom_type(aconn, hstore): + command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout""" + cur = aconn.cursor() + + async with cur.copy(command) as copy: + rows = await alist(copy.rows()) + + assert rows == [('"a"=>"1", "b"=>"2"',)] + + register_hstore(await TypeInfo.fetch(aconn, "hstore"), cur) + async with cur.copy(command) as copy: + copy.set_types(["hstore"]) + rows = await alist(copy.rows()) + + assert rows == [({"a": "1", "b": "2"},)] + + +@pytest.mark.parametrize("format", Format) +async def test_copy_out_allchars(aconn, format): + cur = aconn.cursor() + chars = list(map(chr, range(1, 256))) + [eur] + await aconn.execute("set client_encoding to utf8") + rows = [] + query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format( + chars, sql.SQL(format.name) + ) + async with cur.copy(query) as copy: + copy.set_types(["text"]) + while True: + row = await copy.read_row() + if not row: + break + assert len(row) == 1 + rows.append(row[0]) + + assert rows == chars + + +@pytest.mark.parametrize("format", Format) +async def test_read_row_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = [] + while True: + row = await copy.read_row() + if not row: + break + rows.append(row) + + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("format", Format) +async def test_rows_notypes(aconn, format): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + rows = await alist(copy.rows()) + ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records] + assert rows == ref + + +@pytest.mark.parametrize("err", [-1, 1]) +@pytest.mark.parametrize("format", Format) +async def test_copy_out_badntypes(aconn, format, err): + cur = aconn.cursor() + async with cur.copy( + f"copy ({sample_values}) to stdout (format {format.name})" + ) as copy: + copy.set_types([0] * (len(sample_records[0]) + err)) + with pytest.raises(e.ProgrammingError): + await copy.read_row() + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_copy_in_buffers(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_buffers_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_bad_result(aconn): + await aconn.set_autocommit(True) + + cur = aconn.cursor() + + with pytest.raises(e.SyntaxError): + async with cur.copy("wat"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("reset timezone"): + pass + + with pytest.raises(e.ProgrammingError): + async with cur.copy("copy (select 1) to stdout; select 1") as copy: + await alist(copy) + + with pytest.raises(e.ProgrammingError): + async with cur.copy("select 1; copy (select 1) to stdout"): + pass + + +async def test_copy_in_str(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text.decode()) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_copy_in_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled): + async with cur.copy("copy copy_in from stdin (format binary)") as copy: + await copy.write(sample_text.decode()) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + assert cur.rowcount == 0 + + +@pytest.mark.slow +async def test_copy_big_size_record(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write_row([data]) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview]) +async def test_copy_big_size_block(aconn, pytype): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n") + async with cur.copy("copy copy_in (data) from stdin") as copy: + await copy.write(copy_data) + + await cur.execute("select data from copy_in limit 1") + assert await cur.fetchone() == (data,) + + +@pytest.mark.parametrize("format", Format) +async def test_subclass_adapter(aconn, format): + if format == Format.TEXT: + from psycopg.types.string import StrDumper as BaseDumper + else: + from psycopg.types.string import ( # type: ignore[no-redef] + StrBinaryDumper as BaseDumper, + ) + + class MyStrDumper(BaseDumper): + def dump(self, obj): + return super().dump(obj) * 2 + + aconn.adapters.register_dumper(str, MyStrDumper) + + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy( + f"copy copy_in (data) from stdin (format {format.name})" + ) as copy: + await copy.write_row(("hello",)) + + await cur.execute("select data from copy_in") + rec = await cur.fetchone() + assert rec[0] == "hellohello" + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_error_empty(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy(f"copy copy_in from stdin (format {format.name})"): + raise Exception("mannaggiamiseria") + + assert "mannaggiamiseria" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_pg_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.UniqueViolation): + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + await copy.write(sample_text) + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_in_buffers_with_py_error(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + with pytest.raises(e.QueryCanceled) as exc: + async with cur.copy("copy copy_in from stdin (format text)") as copy: + await copy.write(sample_text) + raise Exception("nuttengoggenio") + + assert "nuttengoggenio" in str(exc.value) + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_error_with_copy_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS + + +async def test_copy_out_error_with_copy_not_finished(aconn): + cur = aconn.cursor() + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy (select generate_series(1, 1000000)) to stdout" + ) as copy: + await copy.read_row() + 1 / 0 + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +async def test_copy_out_server_error(aconn): + cur = aconn.cursor() + with pytest.raises(e.DivisionByZero): + async with cur.copy( + "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout" + ) as copy: + async for block in copy: + pass + + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + for row in sample_records: + if format == Format.BINARY: + row = tuple( + Int4(i) if isinstance(i, int) else i for i in row + ) # type: ignore[assignment] + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_set_types(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: + copy.set_types(["int4", "int4", "text"]) + for row in sample_records: + await copy.write_row(row) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.parametrize("format", Format) +async def test_copy_in_records_binary(aconn, format): + cur = aconn.cursor() + await ensure_table(cur, "col1 serial primary key, col2 int, data text") + + async with cur.copy( + f"copy copy_in (col2, data) from stdin (format {format.name})" + ) as copy: + for row in sample_records: + await copy.write_row((None, row[2])) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == [(1, None, "hello"), (2, None, "world")] + + +async def test_copy_in_allchars(aconn): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + + await aconn.execute("set client_encoding to utf8") + async with cur.copy("copy copy_in from stdin (format text)") as copy: + for i in range(1, 256): + await copy.write_row((i, None, chr(i))) + await copy.write_row((ord(eur), None, eur)) + + await cur.execute( + """ +select col1 = ascii(data), col2 is null, length(data), count(*) +from copy_in group by 1, 2, 3 +""" + ) + data = await cur.fetchall() + assert data == [(True, True, 1, 256)] + + +async def test_copy_in_format(aconn): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy: + for i in range(1, 256): + await copy.write_row((i, chr(i))) + + file.seek(0) + rows = file.read().split(b"\n") + assert not rows[-1] + del rows[-1] + + for i, row in enumerate(rows, start=1): + fields = row.split(b"\t") + assert len(fields) == 2 + assert int(fields[0].decode()) == i + if i in special_chars: + assert fields[1].decode() == f"\\{special_chars[i]}" + else: + assert fields[1].decode() == chr(i) + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_file_writer(aconn, format, buffer): + file = BytesIO() + await aconn.execute("set client_encoding to utf8") + cur = aconn.cursor() + async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy: + for record in sample_records: + await copy.write_row(record) + + file.seek(0) + want = globals()[buffer] + got = file.read() + assert got == want + + +@pytest.mark.slow +async def test_copy_from_to(aconn): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview]) +async def test_copy_from_to_bytes(aconn, pytype): + # Roundtrip from file to database to file blockwise + gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(pytype(block.encode())) + + await gen.assert_data() + + f = BytesIO() + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + f.write(block) + + f.seek(0) + assert gen.sha(f) == gen.sha(gen.file()) + + +@pytest.mark.slow +async def test_copy_from_insane_size(aconn): + # Trying to trigger a "would block" error + gen = DataGenerator( + aconn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024 + ) + await gen.ensure_table() + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + + await gen.assert_data() + + +async def test_copy_rowcount(aconn): + gen = DataGenerator(aconn, nrecs=3, srec=10) + await gen.ensure_table() + + cur = aconn.cursor() + async with cur.copy("copy copy_in from stdin") as copy: + for block in gen.blocks(): + await copy.write(block) + assert cur.rowcount == 3 + + gen = DataGenerator(aconn, nrecs=2, srec=10, offset=3) + async with cur.copy("copy copy_in from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == 2 + + async with cur.copy("copy copy_in to stdout") as copy: + async for block in copy: + pass + assert cur.rowcount == 5 + + with pytest.raises(e.BadCopyFileFormat): + async with cur.copy("copy copy_in (id) from stdin") as copy: + for rec in gen.records(): + await copy.write_row(rec) + assert cur.rowcount == -1 + + +async def test_copy_query(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert cur._query.query == b"copy (select 1) to stdout" + assert not cur._query.params + await alist(copy) + + +async def test_cant_reenter(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + await alist(copy) + + with pytest.raises(TypeError): + async with copy: + await alist(copy) + + +async def test_str(aconn): + cur = aconn.cursor() + async with cur.copy("copy (select 1) to stdout") as copy: + assert "[ACTIVE]" in str(copy) + await alist(copy) + + assert "[INTRANS]" in str(copy) + + +async def test_description(aconn): + async with aconn.cursor() as cur: + async with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy: + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + await alist(copy.rows()) + + len(cur.description) == 3 + assert cur.description[0].name == "column_1" + assert cur.description[2].name == "column_3" + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_worker_life(aconn, format, buffer): + cur = aconn.cursor() + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", + writer=AsyncQueuedLibpqWriter(cur), + ) as copy: + assert not copy.writer._worker + await copy.write(globals()[buffer]) + assert copy.writer._worker + + assert not copy.writer._worker + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +async def test_worker_error_propagated(aconn, monkeypatch): + def copy_to_broken(pgconn, buffer): + raise ZeroDivisionError + yield + + monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + cur = aconn.cursor() + await cur.execute("create temp table wat (a text, b text)") + with pytest.raises(ZeroDivisionError): + async with cur.copy( + "copy wat from stdin", writer=AsyncQueuedLibpqWriter(cur) + ) as copy: + await copy.write("a,b") + + +@pytest.mark.parametrize( + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] +) +async def test_connection_writer(aconn, format, buffer): + cur = aconn.cursor() + writer = AsyncLibpqWriter(cur) + + await ensure_table(cur, sample_tabledef) + async with cur.copy( + f"copy copy_in from stdin (format {format.name})", writer=writer + ) as copy: + assert copy.writer is writer + await copy.write(globals()[buffer]) + + await cur.execute("select * from copy_in order by 1") + data = await cur.fetchall() + assert data == sample_records + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) +async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + async with faker.find_insert_problem_async(conn): + await cur.executemany(faker.insert_stmt, faker.records) + + stmt = sql.SQL( + "copy (select {} from {} order by id) to stdout (format {})" + ).format( + sql.SQL(", ").join(faker.fields_names), + faker.table_name, + sql.SQL(fmt.name), + ) + + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + + if method == "read": + while True: + tmp = await copy.read() + if not tmp: + break + elif method == "iter": + await alist(copy) + elif method == "row": + while True: + tmp = await copy.read_row() + if tmp is None: + break + elif method == "rows": + await alist(copy.rows()) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize( + "fmt, set_types", + [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], +) +async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types): + faker.format = PyFormat.from_pq(fmt) + faker.choose_schema(ncols=20) + faker.make_records(20) + + async def work(): + async with await aconn_cls.connect(dsn) as conn: + async with conn.cursor(binary=fmt) as cur: + await cur.execute(faker.drop_stmt) + await cur.execute(faker.create_stmt) + + stmt = sql.SQL("copy {} ({}) from stdin (format {})").format( + faker.table_name, + sql.SQL(", ").join(faker.fields_names), + sql.SQL(fmt.name), + ) + async with cur.copy(stmt) as copy: + if set_types: + copy.set_types(faker.types_names) + for row in faker.records: + await copy.write_row(row) + + await cur.execute(faker.select_stmt) + recs = await cur.fetchall() + + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + gc_collect() + n = [] + for i in range(3): + await work() + gc_collect() + n.append(gc_count()) + + assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +@pytest.mark.slow +@pytest.mark.parametrize("mode", ["row", "block", "binary"]) +async def test_copy_table_across(aconn_cls, dsn, faker, mode): + faker.choose_schema(ncols=20) + faker.make_records(20) + + connect = aconn_cls.connect + async with await connect(dsn) as conn1, await connect(dsn) as conn2: + faker.table_name = sql.Identifier("copy_src") + await conn1.execute(faker.drop_stmt) + await conn1.execute(faker.create_stmt) + await conn1.cursor().executemany(faker.insert_stmt, faker.records) + + faker.table_name = sql.Identifier("copy_tgt") + await conn2.execute(faker.drop_stmt) + await conn2.execute(faker.create_stmt) + + fmt = "(format binary)" if mode == "binary" else "" + async with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1: + async with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2: + if mode == "row": + async for row in copy1.rows(): + await copy2.write_row(row) + else: + async for data in copy1: + await copy2.write(data) + + cur = await conn2.execute(faker.select_stmt) + recs = await cur.fetchall() + for got, want in zip(recs, faker.records): + faker.assert_record(got, want) + + +async def ensure_table(cur, tabledef, name="copy_in"): + await cur.execute(f"drop table if exists {name}") + await cur.execute(f"create table {name} ({tabledef})") + + +class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): + self.conn = conn + self.nrecs = nrecs + self.srec = srec + self.offset = offset + self.block_size = block_size + + async def ensure_table(self): + cur = self.conn.cursor() + await ensure_table(cur, "id integer primary key, data text") + + def records(self): + for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)): + s = c * self.srec + yield (i + self.offset, s) + + def file(self): + f = StringIO() + for i, s in self.records(): + f.write("%s\t%s\n" % (i, s)) + + f.seek(0) + return f + + def blocks(self): + f = self.file() + while True: + block = f.read(self.block_size) + if not block: + break + yield block + + async def assert_data(self): + cur = self.conn.cursor() + await cur.execute("select id, data from copy_in order by id") + for record in self.records(): + assert record == await cur.fetchone() + + assert await cur.fetchone() is None + + def sha(self, f): + m = hashlib.sha256() + while True: + block = f.read() + if not block: + break + if isinstance(block, str): + block = block.encode() + m.update(block) + return m.hexdigest() + + +class AsyncFileWriter(AsyncWriter): + def __init__(self, file): + self.file = file + + async def write(self, data): + self.file.write(data) |