Diffstat (limited to 'psycopg/psycopg/_pipeline.py')
-rw-r--r--    psycopg/psycopg/_pipeline.py    288
1 file changed, 288 insertions(+), 0 deletions(-)
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
new file mode 100644
index 0000000..c818d86
--- /dev/null
+++ b/psycopg/psycopg/_pipeline.py
@@ -0,0 +1,288 @@
+"""
+Commands pipeline management.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+from types import TracebackType
+from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from .abc import PipelineCommand, PQGen
+from ._compat import Deque
+from ._encodings import pgconn_encoding
+from ._preparing import Key, Prepare
+from .generators import pipeline_communicate, fetch_many, send
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+ from .cursor import BaseCursor
+ from .connection import BaseConnection, Connection
+ from .connection_async import AsyncConnection
+
+
+PendingResult: TypeAlias = Union[
+ None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
+]
+
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+BAD = pq.ConnStatus.BAD
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+logger = logging.getLogger("psycopg")
+
+
+class BasePipeline:
+
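+    # The two queues advance in lockstep: every command appended to
+    # command_queue gets a matching entry in result_queue describing how
+    # its results will be handled (None for sync points, a cursor plus
+    # optional prepared-statement info otherwise).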
+ command_queue: Deque[PipelineCommand]
+ result_queue: Deque[PendingResult]
+ _is_supported: Optional[bool] = None
+
+ def __init__(self, conn: "BaseConnection[Any]") -> None:
+ self._conn = conn
+ self.pgconn = conn.pgconn
+ self.command_queue = Deque[PipelineCommand]()
+ self.result_queue = Deque[PendingResult]()
+ self.level = 0
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._conn.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def status(self) -> pq.PipelineStatus:
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ @classmethod
+ def is_supported(cls) -> bool:
+ """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
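+        # Cache the value on BasePipeline explicitly (not on cls) so that
+        # Pipeline and AsyncPipeline share a single cached result.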
+ if BasePipeline._is_supported is None:
+ BasePipeline._is_supported = not cls._not_supported_reason()
+ return BasePipeline._is_supported
+
+ @classmethod
+ def _not_supported_reason(cls) -> str:
+        """Return the reason why pipeline mode is not supported.
+
+ Return an empty string if pipeline mode is supported.
+ """
+ # Support only depends on the libpq functions available in the pq
+ # wrapper, not on the database version.
+ if pq.version() < 140000:
+ return (
+                f"libpq too old: {pq.version()};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ if pq.__build_version__ < 140000:
+ return (
+ f"libpq too old: module built for {pq.__build_version__};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ return ""
+
+ def _enter_gen(self) -> PQGen[None]:
+ if not self.is_supported():
+ raise e.NotSupportedError(
+ f"pipeline mode not supported: {self._not_supported_reason()}"
+ )
+ if self.level == 0:
+ self.pgconn.enter_pipeline_mode()
+ elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
+ # Nested pipeline case.
+ # Transaction might be ACTIVE when the pipeline uses an "implicit
+ # transaction", typically in autocommit mode. But when entering a
+ # Psycopg transaction(), we expect the IDLE state. By sync()-ing,
+ # we make sure all previous commands are completed and the
+ # transaction gets back to IDLE.
+ yield from self._sync_gen()
+ self.level += 1
+
+ def _exit(self, exc: Optional[BaseException]) -> None:
+ self.level -= 1
+ if self.level == 0 and self.pgconn.status != BAD:
+ try:
+ self.pgconn.exit_pipeline_mode()
+ except e.OperationalError as exc2:
+ # Notice that this error might be pretty irrecoverable. It
+ # happens on COPY, for instance: even if sync succeeds, exiting
+ # fails with "cannot exit pipeline mode with uncollected results"
+ if exc:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+
+ def _sync_gen(self) -> PQGen[None]:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ yield from self._fetch_gen(flush=False)
+
+ def _exit_gen(self) -> PQGen[None]:
+ """
+        Exit the current pipeline by sending a Sync and fetching back all
+        remaining results.
+ """
+ try:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ finally:
+ # No need to force flush since we emitted a sync just before.
+ yield from self._fetch_gen(flush=False)
+
+ def _communicate_gen(self) -> PQGen[None]:
+        """Communicate with the pipeline to send queued commands and
+        possibly fetch results, which are then processed.
+        """
+ fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
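+        # Result sets come back in the same order the commands were sent,
+        # so pair each fetched batch with the oldest queued item.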
+ to_process = [(self.result_queue.popleft(), results) for results in fetched]
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+ """Fetch available results from the connection and process them with
+ pipeline queued items.
+
+        If 'flush' is True, a PQsendFlushRequest() is issued in order to
+        make sure results can be fetched. Otherwise, the caller is expected
+        to have emitted a PQpipelineSync() call, ensuring the output buffer
+        gets flushed before fetching.
+ """
+ if not self.result_queue:
+ return
+
+ if flush:
+ self.pgconn.send_flush_request()
+ yield from send(self.pgconn)
+
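+        # Drain every available result set and pair it with its queued
+        # item before processing any of them.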
+ to_process = []
+ while self.result_queue:
+ results = yield from fetch_many(self.pgconn)
+ if not results:
+ # No more results to fetch, but there may still be pending
+ # commands.
+ break
+ queued = self.result_queue.popleft()
+ to_process.append((queued, results))
+
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _process_results(
+ self, queued: PendingResult, results: List["PGresult"]
+ ) -> None:
+        """Process a result set fetched from the current pipeline.
+
+        This matches 'results' with its respective element in the pipeline
+        queue. For commands (a None value in the pipeline queue), results
+        are checked directly. For prepared statement creation requests, the
+        cache is updated. Otherwise, results are attached to their
+        respective cursor.
+ """
+ if queued is None:
+ (result,) = results
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ elif result.status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ cursor, prepinfo = queued
+ cursor._set_results_from_pipeline(results)
+ if prepinfo:
+ key, prep, name = prepinfo
+ # Update the prepare state of the query.
+ cursor._conn._prepared.validate(key, prep, name, results)
+
+ def _enqueue_sync(self) -> None:
+ """Enqueue a PQpipelineSync() command."""
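+        # The matching None enqueued in result_queue marks a sync point:
+        # _process_results() checks its result directly instead of
+        # attaching it to a cursor.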
+ self.command_queue.append(self.pgconn.pipeline_sync)
+ self.result_queue.append(None)
+
+
+class Pipeline(BasePipeline):
+ """Handler for connection in pipeline mode."""
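+    # A minimal usage sketch (assuming `conn` is a psycopg.Connection:
+    # its pipeline() method enters pipeline mode and yields an instance
+    # of this class):
+    #
+    #     with conn.pipeline() as p:
+    #         cur = conn.execute("SELECT 1")
+    #         p.sync()  # send pending commands, process available results
+    #         print(cur.fetchone())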
+
+ __module__ = "psycopg"
+ _conn: "Connection[Any]"
+ _Self = TypeVar("_Self", bound="Pipeline")
+
+ def __init__(self, conn: "Connection[Any]") -> None:
+ super().__init__(conn)
+
+ def sync(self) -> None:
+        """Sync the pipeline, sending any pending commands and receiving
+        and processing all available results.
+        """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
+
+
+class AsyncPipeline(BasePipeline):
+ """Handler for async connection in pipeline mode."""
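+    # A minimal usage sketch (assuming `aconn` is a psycopg.AsyncConnection:
+    # its pipeline() method yields an instance of this class):
+    #
+    #     async with aconn.pipeline() as p:
+    #         cur = await aconn.execute("SELECT 1")
+    #         await p.sync()
+    #         print(await cur.fetchone())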
+
+ __module__ = "psycopg"
+ _conn: "AsyncConnection[Any]"
+ _Self = TypeVar("_Self", bound="AsyncPipeline")
+
+ def __init__(self, conn: "AsyncConnection[Any]") -> None:
+ super().__init__(conn)
+
+ async def sync(self) -> None:
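+        """Sync the pipeline, sending any pending commands and receiving
+        and processing all available results.
+        """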
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)