summaryrefslogtreecommitdiffstats
path: root/tests/test_copy.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_copy.py889
1 files changed, 889 insertions, 0 deletions
diff --git a/tests/test_copy.py b/tests/test_copy.py
new file mode 100644
index 0000000..17cf2fc
--- /dev/null
+++ b/tests/test_copy.py
@@ -0,0 +1,889 @@
+import string
+import struct
+import hashlib
+from io import BytesIO, StringIO
+from random import choice, randrange
+from itertools import cycle
+
+import pytest
+
+import psycopg
+from psycopg import pq
+from psycopg import sql
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter
+from psycopg.adapt import PyFormat
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import register_hstore
+from psycopg.types.numeric import Int4
+
+from .utils import eur, gc_collect, gc_count
+
+pytestmark = pytest.mark.crdb_skip("copy")
+
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
+sample_tabledef = "col1 serial primary key, col2 int, data text"
+
+sample_text = b"""\
+40010\t40020\thello
+40040\t\\N\tworld
+"""
+
+sample_binary_str = """
+5047 434f 5059 0aff 0d0a 00
+00 0000 0000 0000 00
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
+
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
+
+ff ff
+"""
+
+sample_binary_rows = [
+ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
+]
+sample_binary = b"".join(sample_binary_rows)
+
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_read(conn, format):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ for row in want:
+ got = copy.read()
+ assert got == row
+ assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE
+
+ assert copy.read() == b""
+ assert copy.read() == b""
+
+ assert copy.read() == b""
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_iter(conn, format, row_factory):
+ if format == pq.Format.TEXT:
+ want = [row + b"\n" for row in sample_text.splitlines()]
+ else:
+ want = sample_binary_rows
+
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ assert list(copy) == want
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_no_result(conn, format, row_factory):
+ rf = getattr(psycopg.rows, row_factory)
+ cur = conn.cursor(row_factory=rf)
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+ with pytest.raises(e.ProgrammingError):
+ cur.fetchone()
+
+
+@pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
+def test_copy_out_param(conn, ph, params):
+ cur = conn.cursor()
+ with cur.copy(
+ f"copy (select * from generate_series(1, {ph})) to stdout", params
+ ) as copy:
+ copy.set_types(["int4"])
+ assert list(copy.rows()) == [(i + 1,) for i in range(10)]
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+def test_read_rows(conn, format, typetype):
+ cur = conn.cursor()
+ with cur.copy(
+ f"""copy (
+ select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+ ) to stdout (format {format.name})"""
+ ) as copy:
+ copy.set_types(["int4", "text", "float8[]"])
+ row = copy.read_row()
+ assert copy.read_row() is None
+
+ assert row == (10, "hello", [0.0, 1.0])
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ rows = list(copy.rows())
+
+ assert rows == sample_records
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_set_custom_type(conn, hstore):
+ command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+ cur = conn.cursor()
+
+ with cur.copy(command) as copy:
+ rows = list(copy.rows())
+
+ assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+ register_hstore(TypeInfo.fetch(conn, "hstore"), cur)
+ with cur.copy(command) as copy:
+ copy.set_types(["hstore"])
+ rows = list(copy.rows())
+
+ assert rows == [({"a": "1", "b": "2"},)]
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_allchars(conn, format):
+ cur = conn.cursor()
+ chars = list(map(chr, range(1, 256))) + [eur]
+ conn.execute("set client_encoding to utf8")
+ rows = []
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
+ with cur.copy(query) as copy:
+ copy.set_types(["text"])
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ assert len(row) == 1
+ rows.append(row[0])
+
+ assert rows == chars
+
+
+@pytest.mark.parametrize("format", Format)
+def test_read_row_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = []
+ while True:
+ row = copy.read_row()
+ if not row:
+ break
+ rows.append(row)
+
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("format", Format)
+def test_rows_notypes(conn, format):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ rows = list(copy.rows())
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
+ assert rows == ref
+
+
+@pytest.mark.parametrize("err", [-1, 1])
+@pytest.mark.parametrize("format", Format)
+def test_copy_out_badntypes(conn, format, err):
+ cur = conn.cursor()
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
+ copy.set_types([0] * (len(sample_records[0]) + err))
+ with pytest.raises(e.ProgrammingError):
+ copy.read_row()
+
+
+@pytest.mark.parametrize(
+ "format, buffer",
+ [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+)
+def test_copy_in_buffers(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_buffers_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_bad_result(conn):
+ conn.autocommit = True
+
+ cur = conn.cursor()
+
+ with pytest.raises(e.SyntaxError):
+ with cur.copy("wat"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("reset timezone"):
+ pass
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("copy (select 1) to stdout; select 1") as copy:
+ list(copy)
+
+ with pytest.raises(e.ProgrammingError):
+ with cur.copy("select 1; copy (select 1) to stdout"):
+ pass
+
+
+def test_copy_in_str(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text.decode())
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_copy_in_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled):
+ with cur.copy("copy copy_in from stdin (format binary)") as copy:
+ copy.write(sample_text.decode())
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+ assert cur.rowcount == 0
+
+
+@pytest.mark.slow
+def test_copy_big_size_record(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write_row([data])
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+def test_copy_big_size_block(conn, pytype):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
+ with cur.copy("copy copy_in (data) from stdin") as copy:
+ copy.write(copy_data)
+
+ cur.execute("select data from copy_in limit 1")
+ assert cur.fetchone()[0] == data
+
+
+@pytest.mark.parametrize("format", Format)
+def test_subclass_adapter(conn, format):
+ if format == Format.TEXT:
+ from psycopg.types.string import StrDumper as BaseDumper
+ else:
+ from psycopg.types.string import ( # type: ignore[no-redef]
+ StrBinaryDumper as BaseDumper,
+ )
+
+ class MyStrDumper(BaseDumper):
+ def dump(self, obj):
+ return super().dump(obj) * 2
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
+ copy.write_row(("hello",))
+
+ rec = cur.execute("select data from copy_in").fetchone()
+ assert rec[0] == "hellohello"
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_error_empty(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy(f"copy copy_in from stdin (format {format.name})"):
+ raise Exception("mannaggiamiseria")
+
+ assert "mannaggiamiseria" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_pg_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.UniqueViolation):
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ copy.write(sample_text)
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_in_buffers_with_py_error(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with pytest.raises(e.QueryCanceled) as exc:
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ copy.write(sample_text)
+ raise Exception("nuttengoggenio")
+
+ assert "nuttengoggenio" in str(exc.value)
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_error_with_copy_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_copy_out_error_with_copy_not_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_server_error(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ for block in copy:
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ for row in sample_records:
+ if format == Format.BINARY:
+ row = tuple(
+ Int4(i) if isinstance(i, int) else i for i in row
+ ) # type: ignore[assignment]
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_set_types(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
+ copy.set_types(["int4", "int4", "text"])
+ for row in sample_records:
+ copy.write_row(row)
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.parametrize("format", Format)
+def test_copy_in_records_binary(conn, format):
+ cur = conn.cursor()
+ ensure_table(cur, "col1 serial primary key, col2 int, data text")
+
+ with cur.copy(
+ f"copy copy_in (col2, data) from stdin (format {format.name})"
+ ) as copy:
+ for row in sample_records:
+ copy.write_row((None, row[2]))
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == [(1, None, "hello"), (2, None, "world")]
+
+
+def test_copy_in_allchars(conn):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+
+ conn.execute("set client_encoding to utf8")
+ with cur.copy("copy copy_in from stdin (format text)") as copy:
+ for i in range(1, 256):
+ copy.write_row((i, None, chr(i)))
+ copy.write_row((ord(eur), None, eur))
+
+ data = cur.execute(
+ """
+select col1 = ascii(data), col2 is null, length(data), count(*)
+from copy_in group by 1, 2, 3
+"""
+ ).fetchall()
+ assert data == [(True, True, 1, 256)]
+
+
+def test_copy_in_format(conn):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, writer=FileWriter(file)) as copy:
+ for i in range(1, 256):
+ copy.write_row((i, chr(i)))
+
+ file.seek(0)
+ rows = file.read().split(b"\n")
+ assert not rows[-1]
+ del rows[-1]
+
+ for i, row in enumerate(rows, start=1):
+ fields = row.split(b"\t")
+ assert len(fields) == 2
+ assert int(fields[0].decode()) == i
+ if i in special_chars:
+ assert fields[1].decode() == f"\\{special_chars[i]}"
+ else:
+ assert fields[1].decode() == chr(i)
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_file_writer(conn, format, buffer):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, binary=format, writer=FileWriter(file)) as copy:
+ for record in sample_records:
+ copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
+@pytest.mark.slow
+def test_copy_from_to(conn):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_copy_from_to_bytes(conn, pytype):
+ # Roundtrip from file to database to file blockwise
+ gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(pytype(block.encode()))
+
+ gen.assert_data()
+
+ f = BytesIO()
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ f.write(block)
+
+ f.seek(0)
+ assert gen.sha(f) == gen.sha(gen.file())
+
+
+@pytest.mark.slow
+def test_copy_from_insane_size(conn):
+ # Trying to trigger a "would block" error
+ gen = DataGenerator(
+ conn, nrecs=4 * 1024, srec=10 * 1024, block_size=20 * 1024 * 1024
+ )
+ gen.ensure_table()
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+
+ gen.assert_data()
+
+
+def test_copy_rowcount(conn):
+ gen = DataGenerator(conn, nrecs=3, srec=10)
+ gen.ensure_table()
+
+ cur = conn.cursor()
+ with cur.copy("copy copy_in from stdin") as copy:
+ for block in gen.blocks():
+ copy.write(block)
+ assert cur.rowcount == 3
+
+ gen = DataGenerator(conn, nrecs=2, srec=10, offset=3)
+ with cur.copy("copy copy_in from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == 2
+
+ with cur.copy("copy copy_in to stdout") as copy:
+ for block in copy:
+ pass
+ assert cur.rowcount == 5
+
+ with pytest.raises(e.BadCopyFileFormat):
+ with cur.copy("copy copy_in (id) from stdin") as copy:
+ for rec in gen.records():
+ copy.write_row(rec)
+ assert cur.rowcount == -1
+
+
+def test_copy_query(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert cur._query.query == b"copy (select 1) to stdout"
+ assert not cur._query.params
+ list(copy)
+
+
+def test_cant_reenter(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ list(copy)
+
+ with pytest.raises(TypeError):
+ with copy:
+ list(copy)
+
+
+def test_str(conn):
+ cur = conn.cursor()
+ with cur.copy("copy (select 1) to stdout") as copy:
+ assert "[ACTIVE]" in str(copy)
+ list(copy)
+
+ assert "[INTRANS]" in str(copy)
+
+
+def test_description(conn):
+ with conn.cursor() as cur:
+ with cur.copy("copy (select 'This', 'Is', 'Text') to stdout") as copy:
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+ list(copy.rows())
+
+ len(cur.description) == 3
+ assert cur.description[0].name == "column_1"
+ assert cur.description[2].name == "column_3"
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_worker_life(conn, format, buffer):
+ cur = conn.cursor()
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=QueuedLibpqDriver(cur)
+ ) as copy:
+ assert not copy.writer._worker
+ copy.write(globals()[buffer])
+ assert copy.writer._worker
+
+ assert not copy.writer._worker
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+def test_worker_error_propagated(conn, monkeypatch):
+ def copy_to_broken(pgconn, buffer):
+ raise ZeroDivisionError
+ yield
+
+ monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+ cur = conn.cursor()
+ cur.execute("create temp table wat (a text, b text)")
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy wat from stdin", writer=QueuedLibpqDriver(cur)) as copy:
+ copy.write("a,b")
+
+
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_connection_writer(conn, format, buffer):
+ cur = conn.cursor()
+ writer = LibpqWriter(cur)
+
+ ensure_table(cur, sample_tabledef)
+ with cur.copy(
+ f"copy copy_in from stdin (format {format.name})", writer=writer
+ ) as copy:
+ assert copy.writer is writer
+ copy.write(globals()[buffer])
+
+ data = cur.execute("select * from copy_in order by 1").fetchall()
+ assert data == sample_records
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+@pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
+def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ stmt = sql.SQL(
+ "copy (select {} from {} order by id) to stdout (format {})"
+ ).format(
+ sql.SQL(", ").join(faker.fields_names),
+ faker.table_name,
+ sql.SQL(fmt.name),
+ )
+
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+
+ if method == "read":
+ while True:
+ tmp = copy.read()
+ if not tmp:
+ break
+ elif method == "iter":
+ list(copy)
+ elif method == "row":
+ while True:
+ tmp = copy.read_row()
+ if tmp is None:
+ break
+ elif method == "rows":
+ list(copy.rows())
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+ "fmt, set_types",
+ [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+)
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+ faker.format = PyFormat.from_pq(fmt)
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ def work():
+ with conn_cls.connect(dsn) as conn:
+ with conn.cursor(binary=fmt) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+
+ stmt = sql.SQL("copy {} ({}) from stdin (format {})").format(
+ faker.table_name,
+ sql.SQL(", ").join(faker.fields_names),
+ sql.SQL(fmt.name),
+ )
+ with cur.copy(stmt) as copy:
+ if set_types:
+ copy.set_types(faker.types_names)
+ for row in faker.records:
+ copy.write_row(row)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+ gc_collect()
+ n = []
+ for i in range(3):
+ work()
+ gc_collect()
+ n.append(gc_count())
+
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("mode", ["row", "block", "binary"])
+def test_copy_table_across(conn_cls, dsn, faker, mode):
+ faker.choose_schema(ncols=20)
+ faker.make_records(20)
+
+ with conn_cls.connect(dsn) as conn1, conn_cls.connect(dsn) as conn2:
+ faker.table_name = sql.Identifier("copy_src")
+ conn1.execute(faker.drop_stmt)
+ conn1.execute(faker.create_stmt)
+ conn1.cursor().executemany(faker.insert_stmt, faker.records)
+
+ faker.table_name = sql.Identifier("copy_tgt")
+ conn2.execute(faker.drop_stmt)
+ conn2.execute(faker.create_stmt)
+
+ fmt = "(format binary)" if mode == "binary" else ""
+ with conn1.cursor().copy(f"copy copy_src to stdout {fmt}") as copy1:
+ with conn2.cursor().copy(f"copy copy_tgt from stdin {fmt}") as copy2:
+ if mode == "row":
+ for row in copy1.rows():
+ copy2.write_row(row)
+ else:
+ for data in copy1:
+ copy2.write(data)
+
+ recs = conn2.execute(faker.select_stmt).fetchall()
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+def py_to_raw(item, fmt):
+ """Convert from Python type to the expected result from the db"""
+ if fmt == Format.TEXT:
+ if isinstance(item, int):
+ return str(item)
+ else:
+ if isinstance(item, int):
+ # Assume int4
+ return struct.pack("!i", item)
+ elif isinstance(item, str):
+ return item.encode()
+ return item
+
+
+def ensure_table(cur, tabledef, name="copy_in"):
+ cur.execute(f"drop table if exists {name}")
+ cur.execute(f"create table {name} ({tabledef})")
+
+
+class DataGenerator:
+ def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
+ self.conn = conn
+ self.nrecs = nrecs
+ self.srec = srec
+ self.offset = offset
+ self.block_size = block_size
+
+ def ensure_table(self):
+ cur = self.conn.cursor()
+ ensure_table(cur, "id integer primary key, data text")
+
+ def records(self):
+ for i, c in zip(range(self.nrecs), cycle(string.ascii_letters)):
+ s = c * self.srec
+ yield (i + self.offset, s)
+
+ def file(self):
+ f = StringIO()
+ for i, s in self.records():
+ f.write("%s\t%s\n" % (i, s))
+
+ f.seek(0)
+ return f
+
+ def blocks(self):
+ f = self.file()
+ while True:
+ block = f.read(self.block_size)
+ if not block:
+ break
+ yield block
+
+ def assert_data(self):
+ cur = self.conn.cursor()
+ cur.execute("select id, data from copy_in order by id")
+ for record in self.records():
+ assert record == cur.fetchone()
+
+ assert cur.fetchone() is None
+
+ def sha(self, f):
+ m = hashlib.sha256()
+ while True:
+ block = f.read()
+ if not block:
+ break
+ if isinstance(block, str):
+ block = block.encode()
+ m.update(block)
+ return m.hexdigest()