diff options
Diffstat (limited to '')
-rw-r--r-- | tests/scripts/pipeline-demo.py | 340 |
1 files changed, 340 insertions, 0 deletions
diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py new file mode 100644 index 0000000..ec95229 --- /dev/null +++ b/tests/scripts/pipeline-demo.py @@ -0,0 +1,340 @@ +"""Pipeline mode demo + +This reproduces libpq_pipeline::pipelined_insert PostgreSQL test at +src/test/modules/libpq_pipeline/libpq_pipeline.c::test_pipelined_insert(). + +We do not fetch results explicitly (using cursor.fetch*()), this is +handled by execute() calls when pgconn socket is read-ready, which +happens when the output buffer is full. +""" +import argparse +import asyncio +import logging +from contextlib import contextmanager +from functools import partial +from typing import Any, Iterator, Optional, Sequence, Tuple + +from psycopg import AsyncConnection, Connection +from psycopg import pq, waiting +from psycopg import errors as e +from psycopg.abc import PipelineCommand +from psycopg.generators import pipeline_communicate +from psycopg.pq import Format, DiagnosticField +from psycopg._compat import Deque + +psycopg_logger = logging.getLogger("psycopg") +pipeline_logger = logging.getLogger("pipeline") +args: argparse.Namespace + + +class LoggingPGconn: + """Wrapper for PGconn that logs fetched results.""" + + def __init__(self, pgconn: pq.abc.PGconn, logger: logging.Logger): + self._pgconn = pgconn + self._logger = logger + + def log_notice(result: pq.abc.PGresult) -> None: + def get_field(field: DiagnosticField) -> Optional[str]: + value = result.error_field(field) + return value.decode("utf-8", "replace") if value else None + + logger.info( + "notice %s %s", + get_field(DiagnosticField.SEVERITY), + get_field(DiagnosticField.MESSAGE_PRIMARY), + ) + + pgconn.notice_handler = log_notice + + if args.trace: + self._trace_file = open(args.trace, "w") + pgconn.trace(self._trace_file.fileno()) + + def __del__(self) -> None: + if hasattr(self, "_trace_file"): + self._pgconn.untrace() + self._trace_file.close() + + def __getattr__(self, name: str) -> Any: + return getattr(self._pgconn, name) + + def send_query(self, command: bytes) -> None: + self._logger.warning("PQsendQuery broken in libpq 14.5") + self._pgconn.send_query(command) + self._logger.info("sent %s", command.decode()) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_params( + command, param_values, param_types, param_formats, result_format + ) + self._logger.info("sent %s", command.decode()) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[bytes]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + self._pgconn.send_query_prepared( + name, param_values, param_formats, result_format + ) + self._logger.info("sent prepared '%s' with %s", name.decode(), param_values) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + self._pgconn.send_prepare(name, command, param_types) + self._logger.info("prepare %s as '%s'", command.decode(), name.decode()) + + def get_result(self) -> Optional[pq.abc.PGresult]: + r = self._pgconn.get_result() + if r is not None: + self._logger.info("got %s result", pq.ExecStatus(r.status).name) + return r + + +@contextmanager +def prepare_pipeline_demo_pq( + pgconn: LoggingPGconn, rows_to_send: int, logger: logging.Logger +) -> Iterator[Tuple[Deque[PipelineCommand], Deque[str]]]: + """Set up pipeline demo with initial queries and yield commands and + results queue for pipeline_communicate(). + """ + logger.debug("enter pipeline") + pgconn.enter_pipeline_mode() + + setup_queries = [ + ("begin", "BEGIN TRANSACTION"), + ("drop table", "DROP TABLE IF EXISTS pq_pipeline_demo"), + ( + "create table", + ( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ), + ), + ( + "prepare", + ("INSERT INTO pq_pipeline_demo(itemno, int8filler)" " VALUES ($1, $2)"), + ), + ] + + commands = Deque[PipelineCommand]() + results_queue = Deque[str]() + + for qname, query in setup_queries: + if qname == "prepare": + pgconn.send_prepare(qname.encode(), query.encode()) + else: + pgconn.send_query_params(query.encode(), None) + results_queue.append(qname) + + committed = False + synced = False + + while True: + if rows_to_send: + params = [f"{rows_to_send}".encode(), f"{1 << 62}".encode()] + commands.append(partial(pgconn.send_query_prepared, b"prepare", params)) + results_queue.append(f"row {rows_to_send}") + rows_to_send -= 1 + + elif not committed: + committed = True + commands.append(partial(pgconn.send_query_params, b"COMMIT", None)) + results_queue.append("commit") + + elif not synced: + + def sync() -> None: + pgconn.pipeline_sync() + logger.info("pipeline sync sent") + + synced = True + commands.append(sync) + results_queue.append("sync") + + else: + break + + try: + yield commands, results_queue + finally: + logger.debug("exit pipeline") + pgconn.exit_pipeline_mode() + + +def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn(Connection.connect().pgconn, logger) + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = waiting.wait( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None: + pgconn = LoggingPGconn((await AsyncConnection.connect()).pgconn, logger) + + with prepare_pipeline_demo_pq(pgconn, rows_to_send, logger) as ( + commands, + results_queue, + ): + while results_queue: + fetched = await waiting.wait_async( + pipeline_communicate( + pgconn, # type: ignore[arg-type] + commands, + ), + pgconn.socket, + ) + assert not commands, commands + for results in fetched: + results_queue.popleft() + for r in results: + if r.status in ( + pq.ExecStatus.FATAL_ERROR, + pq.ExecStatus.PIPELINE_ABORTED, + ): + raise e.error_from_result(r) + + +def pipeline_demo(rows_to_send: int, many: bool, logger: logging.Logger) -> None: + """Pipeline demo using sync API.""" + conn = Connection.connect() + conn.autocommit = True + conn.pgconn = LoggingPGconn(conn.pgconn, logger) # type: ignore[assignment] + with conn.pipeline(): + with conn.transaction(): + conn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + conn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = conn.cursor() + cur.executemany(query, list(params)) + else: + for p in params: + conn.execute(query, p) + + +async def pipeline_demo_async( + rows_to_send: int, many: bool, logger: logging.Logger +) -> None: + """Pipeline demo using async API.""" + aconn = await AsyncConnection.connect() + await aconn.set_autocommit(True) + aconn.pgconn = LoggingPGconn(aconn.pgconn, logger) # type: ignore[assignment] + async with aconn.pipeline(): + async with aconn.transaction(): + await aconn.execute("DROP TABLE IF EXISTS pq_pipeline_demo") + await aconn.execute( + "CREATE UNLOGGED TABLE pq_pipeline_demo(" + " id serial primary key," + " itemno integer," + " int8filler int8" + ")" + ) + query = "INSERT INTO pq_pipeline_demo(itemno, int8filler) VALUES (%s, %s)" + params = ((r, 1 << 62) for r in range(rows_to_send, 0, -1)) + if many: + cur = aconn.cursor() + await cur.executemany(query, list(params)) + else: + for p in params: + await aconn.execute(query, p) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", + dest="nrows", + metavar="ROWS", + default=10_000, + type=int, + help="number of rows to insert", + ) + parser.add_argument( + "--pq", action="store_true", help="use low-level psycopg.pq API" + ) + parser.add_argument( + "--async", dest="async_", action="store_true", help="use async API" + ) + parser.add_argument( + "--many", + action="store_true", + help="use executemany() (not applicable for --pq)", + ) + parser.add_argument("--trace", help="write trace info into TRACE file") + parser.add_argument("-l", "--log", help="log file (stderr by default)") + + global args + args = parser.parse_args() + + psycopg_logger.setLevel(logging.DEBUG) + pipeline_logger.setLevel(logging.DEBUG) + if args.log: + psycopg_logger.addHandler(logging.FileHandler(args.log)) + pipeline_logger.addHandler(logging.FileHandler(args.log)) + else: + psycopg_logger.addHandler(logging.StreamHandler()) + pipeline_logger.addHandler(logging.StreamHandler()) + + if args.pq: + if args.many: + parser.error("--many cannot be used with --pq") + if args.async_: + asyncio.run(pipeline_demo_pq_async(args.nrows, pipeline_logger)) + else: + pipeline_demo_pq(args.nrows, pipeline_logger) + else: + if pq.__impl__ != "python": + parser.error( + "only supported for Python implementation (set PSYCOPG_IMPL=python)" + ) + if args.async_: + asyncio.run(pipeline_demo_async(args.nrows, args.many, pipeline_logger)) + else: + pipeline_demo(args.nrows, args.many, pipeline_logger) + + +if __name__ == "__main__": + main() |