From 506ed8899b3a97e512be3fd6d44d5b11463bf9bf Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 4 May 2024 19:41:08 +0200 Subject: Adding upstream version 3.1.7. Signed-off-by: Daniel Baumann --- tests/test_generators.py | 156 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 tests/test_generators.py (limited to 'tests/test_generators.py') diff --git a/tests/test_generators.py b/tests/test_generators.py new file mode 100644 index 0000000..8aba73f --- /dev/null +++ b/tests/test_generators.py @@ -0,0 +1,156 @@ +from collections import deque +from functools import partial +from typing import List + +import pytest + +import psycopg +from psycopg import waiting +from psycopg import pq + + +@pytest.fixture +def pipeline(pgconn): + nb, pgconn.nonblocking = pgconn.nonblocking, True + assert pgconn.nonblocking + pgconn.enter_pipeline_mode() + yield + if pgconn.pipeline_status: + pgconn.exit_pipeline_mode() + pgconn.nonblocking = nb + + +def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses): + actual_statuses: List[pq.ExecStatus] = [] + while len(actual_statuses) != len(expected_statuses): + if commands: + gen = generators.pipeline_communicate(pgconn, commands) + results = waiting.wait(gen, pgconn.socket) + for (result,) in results: + actual_statuses.append(result.status) + else: + gen = generators.fetch_many(pgconn) + results = waiting.wait(gen, pgconn.socket) + for result in results: + actual_statuses.append(result.status) + + assert actual_statuses == expected_statuses + + +@pytest.mark.pipeline +def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_params, b"select 1", None), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, b"select 2", None), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.mark.pipeline +def test_pipeline_communicate_no_sync(pgconn, pipeline, generators): + numqueries = 10 + commands = deque( + [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)] + * numqueries + + [pgconn.send_flush_request] + ) + expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + + +@pytest.fixture +def pipeline_demo(pgconn): + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + yield "pg_pipeline" + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + + +# TODOCRDB: 1 doesn't get rolled back. Open a ticket? +@pytest.mark.pipeline +@pytest.mark.crdb("skip", reason="pipeline aborted") +def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators): + insert_sql = b"insert into pg_pipeline(itemno) values ($1)" + commands = deque( + [ + partial(pgconn.send_query_params, insert_sql, [b"1"]), + partial(pgconn.send_query_params, b"select no_such_function(1)", None), + partial(pgconn.send_query_params, insert_sql, [b"2"]), + pgconn.pipeline_sync, + partial(pgconn.send_query_params, insert_sql, [b"3"]), + pgconn.pipeline_sync, + ] + ) + expected_statuses = [ + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_SYNC, + pq.ExecStatus.COMMAND_OK, + pq.ExecStatus.PIPELINE_SYNC, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) + pgconn.exit_pipeline_mode() + res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno") + assert res.ntuples == 1 + assert res.get_value(0, 0) == b"3" + + +@pytest.fixture +def pipeline_uniqviol(pgconn): + if not psycopg.Pipeline.is_supported(): + pytest.skip(psycopg.Pipeline._not_supported_reason()) + assert pgconn.pipeline_status == 0 + res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_( + b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol(" + b" id bigint primary key, idata bigint)" + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.exec_(b"BEGIN") + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + res = pgconn.prepare( + b"insertion", + b"insert into pg_pipeline_uniqviol values ($1, $2) returning id", + ) + assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message + return "pg_pipeline_uniqviol" + + +def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators): + commands = deque( + [ + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]), + partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]), + partial(pgconn.send_query_params, b"commit", None), + ] + ) + expected_statuses = [ + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.TUPLES_OK, + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + pq.ExecStatus.PIPELINE_ABORTED, + ] + _run_pipeline_communicate(pgconn, generators, commands, expected_statuses) -- cgit v1.2.3