summaryrefslogtreecommitdiffstats
path: root/tests/test_generators.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_generators.py')
-rw-r--r--tests/test_generators.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/tests/test_generators.py b/tests/test_generators.py
new file mode 100644
index 0000000..8aba73f
--- /dev/null
+++ b/tests/test_generators.py
@@ -0,0 +1,156 @@
+from collections import deque
+from functools import partial
+from typing import List
+
+import pytest
+
+import psycopg
+from psycopg import waiting
+from psycopg import pq
+
+
+@pytest.fixture
+def pipeline(pgconn):
+ nb, pgconn.nonblocking = pgconn.nonblocking, True
+ assert pgconn.nonblocking
+ pgconn.enter_pipeline_mode()
+ yield
+ if pgconn.pipeline_status:
+ pgconn.exit_pipeline_mode()
+ pgconn.nonblocking = nb
+
+
+def _run_pipeline_communicate(pgconn, generators, commands, expected_statuses):
+ actual_statuses: List[pq.ExecStatus] = []
+ while len(actual_statuses) != len(expected_statuses):
+ if commands:
+ gen = generators.pipeline_communicate(pgconn, commands)
+ results = waiting.wait(gen, pgconn.socket)
+ for (result,) in results:
+ actual_statuses.append(result.status)
+ else:
+ gen = generators.fetch_many(pgconn)
+ results = waiting.wait(gen, pgconn.socket)
+ for result in results:
+ actual_statuses.append(result.status)
+
+ assert actual_statuses == expected_statuses
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_multi_pipeline(pgconn, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, b"select 1", None),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, b"select 2", None),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.mark.pipeline
+def test_pipeline_communicate_no_sync(pgconn, pipeline, generators):
+ numqueries = 10
+ commands = deque(
+ [partial(pgconn.send_query_params, b"select repeat('xyzxz', 12)", None)]
+ * numqueries
+ + [pgconn.send_flush_request]
+ )
+ expected_statuses = [pq.ExecStatus.TUPLES_OK] * numqueries
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+
+
+@pytest.fixture
+def pipeline_demo(pgconn):
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline(" b" id serial primary key, itemno integer)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ yield "pg_pipeline"
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+
+
+# TODOCRDB: 1 doesn't get rolled back. Open a ticket?
+@pytest.mark.pipeline
+@pytest.mark.crdb("skip", reason="pipeline aborted")
+def test_pipeline_communicate_abort(pgconn, pipeline_demo, pipeline, generators):
+ insert_sql = b"insert into pg_pipeline(itemno) values ($1)"
+ commands = deque(
+ [
+ partial(pgconn.send_query_params, insert_sql, [b"1"]),
+ partial(pgconn.send_query_params, b"select no_such_function(1)", None),
+ partial(pgconn.send_query_params, insert_sql, [b"2"]),
+ pgconn.pipeline_sync,
+ partial(pgconn.send_query_params, insert_sql, [b"3"]),
+ pgconn.pipeline_sync,
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_SYNC,
+ pq.ExecStatus.COMMAND_OK,
+ pq.ExecStatus.PIPELINE_SYNC,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)
+ pgconn.exit_pipeline_mode()
+ res = pgconn.exec_(b"select itemno from pg_pipeline order by itemno")
+ assert res.ntuples == 1
+ assert res.get_value(0, 0) == b"3"
+
+
+@pytest.fixture
+def pipeline_uniqviol(pgconn):
+ if not psycopg.Pipeline.is_supported():
+ pytest.skip(psycopg.Pipeline._not_supported_reason())
+ assert pgconn.pipeline_status == 0
+ res = pgconn.exec_(b"DROP TABLE IF EXISTS pg_pipeline_uniqviol")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(
+ b"CREATE UNLOGGED TABLE pg_pipeline_uniqviol("
+ b" id bigint primary key, idata bigint)"
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.exec_(b"BEGIN")
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ res = pgconn.prepare(
+ b"insertion",
+ b"insert into pg_pipeline_uniqviol values ($1, $2) returning id",
+ )
+ assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
+ return "pg_pipeline_uniqviol"
+
+
+def test_pipeline_communicate_uniqviol(pgconn, pipeline_uniqviol, pipeline, generators):
+ commands = deque(
+ [
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"2", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"1", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"3", b"2"]),
+ partial(pgconn.send_query_prepared, b"insertion", [b"4", b"2"]),
+ partial(pgconn.send_query_params, b"commit", None),
+ ]
+ )
+ expected_statuses = [
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.TUPLES_OK,
+ pq.ExecStatus.FATAL_ERROR,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ pq.ExecStatus.PIPELINE_ABORTED,
+ ]
+ _run_pipeline_communicate(pgconn, generators, commands, expected_statuses)