summaryrefslogtreecommitdiffstats
path: root/psycopg
diff options
context:
space:
mode:
Diffstat (limited to 'psycopg')
-rw-r--r--psycopg/.flake86
-rw-r--r--psycopg/LICENSE.txt165
-rw-r--r--psycopg/README.rst31
-rw-r--r--psycopg/psycopg/__init__.py110
-rw-r--r--psycopg/psycopg/_adapters_map.py289
-rw-r--r--psycopg/psycopg/_cmodule.py24
-rw-r--r--psycopg/psycopg/_column.py143
-rw-r--r--psycopg/psycopg/_compat.py72
-rw-r--r--psycopg/psycopg/_dns.py223
-rw-r--r--psycopg/psycopg/_encodings.py170
-rw-r--r--psycopg/psycopg/_enums.py79
-rw-r--r--psycopg/psycopg/_pipeline.py288
-rw-r--r--psycopg/psycopg/_preparing.py198
-rw-r--r--psycopg/psycopg/_queries.py375
-rw-r--r--psycopg/psycopg/_struct.py57
-rw-r--r--psycopg/psycopg/_tpc.py116
-rw-r--r--psycopg/psycopg/_transform.py350
-rw-r--r--psycopg/psycopg/_typeinfo.py461
-rw-r--r--psycopg/psycopg/_tz.py44
-rw-r--r--psycopg/psycopg/_wrappers.py137
-rw-r--r--psycopg/psycopg/abc.py266
-rw-r--r--psycopg/psycopg/adapt.py162
-rw-r--r--psycopg/psycopg/client_cursor.py95
-rw-r--r--psycopg/psycopg/connection.py1031
-rw-r--r--psycopg/psycopg/connection_async.py436
-rw-r--r--psycopg/psycopg/conninfo.py378
-rw-r--r--psycopg/psycopg/copy.py904
-rw-r--r--psycopg/psycopg/crdb/__init__.py19
-rw-r--r--psycopg/psycopg/crdb/_types.py163
-rw-r--r--psycopg/psycopg/crdb/connection.py186
-rw-r--r--psycopg/psycopg/cursor.py921
-rw-r--r--psycopg/psycopg/cursor_async.py250
-rw-r--r--psycopg/psycopg/dbapi20.py112
-rw-r--r--psycopg/psycopg/errors.py1535
-rw-r--r--psycopg/psycopg/generators.py320
-rw-r--r--psycopg/psycopg/postgres.py125
-rw-r--r--psycopg/psycopg/pq/__init__.py133
-rw-r--r--psycopg/psycopg/pq/_debug.py106
-rw-r--r--psycopg/psycopg/pq/_enums.py249
-rw-r--r--psycopg/psycopg/pq/_pq_ctypes.py804
-rw-r--r--psycopg/psycopg/pq/_pq_ctypes.pyi216
-rw-r--r--psycopg/psycopg/pq/abc.py385
-rw-r--r--psycopg/psycopg/pq/misc.py146
-rw-r--r--psycopg/psycopg/pq/pq_ctypes.py1086
-rw-r--r--psycopg/psycopg/py.typed0
-rw-r--r--psycopg/psycopg/rows.py256
-rw-r--r--psycopg/psycopg/server_cursor.py479
-rw-r--r--psycopg/psycopg/sql.py467
-rw-r--r--psycopg/psycopg/transaction.py290
-rw-r--r--psycopg/psycopg/types/__init__.py11
-rw-r--r--psycopg/psycopg/types/array.py464
-rw-r--r--psycopg/psycopg/types/bool.py51
-rw-r--r--psycopg/psycopg/types/composite.py290
-rw-r--r--psycopg/psycopg/types/datetime.py754
-rw-r--r--psycopg/psycopg/types/enum.py177
-rw-r--r--psycopg/psycopg/types/hstore.py131
-rw-r--r--psycopg/psycopg/types/json.py232
-rw-r--r--psycopg/psycopg/types/multirange.py514
-rw-r--r--psycopg/psycopg/types/net.py206
-rw-r--r--psycopg/psycopg/types/none.py25
-rw-r--r--psycopg/psycopg/types/numeric.py515
-rw-r--r--psycopg/psycopg/types/range.py700
-rw-r--r--psycopg/psycopg/types/shapely.py75
-rw-r--r--psycopg/psycopg/types/string.py239
-rw-r--r--psycopg/psycopg/types/uuid.py65
-rw-r--r--psycopg/psycopg/version.py14
-rw-r--r--psycopg/psycopg/waiting.py331
-rw-r--r--psycopg/pyproject.toml3
-rw-r--r--psycopg/setup.cfg47
-rw-r--r--psycopg/setup.py66
70 files changed, 19768 insertions, 0 deletions
diff --git a/psycopg/.flake8 b/psycopg/.flake8
new file mode 100644
index 0000000..67fb024
--- /dev/null
+++ b/psycopg/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
+per-file-ignores =
+ # Autogenerated section
+ psycopg/errors.py: E125, E128, E302
diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg/README.rst b/psycopg/README.rst
new file mode 100644
index 0000000..45eeac3
--- /dev/null
+++ b/psycopg/README.rst
@@ -0,0 +1,31 @@
+Psycopg 3: PostgreSQL database adapter for Python
+=================================================
+
+Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python.
+
+This distribution contains the pure Python package ``psycopg``.
+
+
+Installation
+------------
+
+In short, run the following::
+
+ pip install --upgrade pip # to upgrade pip
+ pip install "psycopg[binary,pool]" # to install package and dependencies
+
+If something goes wrong, and for more information about installation, please
+check out the `Installation documentation`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html#
+
+
+Hacking
+-------
+
+For development information check out `the project readme`__.
+
+.. __: https://github.com/psycopg/psycopg#readme
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
new file mode 100644
index 0000000..baadf30
--- /dev/null
+++ b/psycopg/psycopg/__init__.py
@@ -0,0 +1,110 @@
+"""
+psycopg -- PostgreSQL database adapter for Python
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from . import pq # noqa: F401 import early to stabilize side effects
+from . import types
+from . import postgres
+from ._tpc import Xid
+from .copy import Copy, AsyncCopy
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from .errors import Warning, Error, InterfaceError, DatabaseError
+from .errors import DataError, OperationalError, IntegrityError
+from .errors import InternalError, ProgrammingError, NotSupportedError
+from ._column import Column
+from .conninfo import ConnectionInfo
+from ._pipeline import Pipeline, AsyncPipeline
+from .connection import BaseConnection, Connection, Notify
+from .transaction import Rollback, Transaction, AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor, ServerCursor
+from .client_cursor import AsyncClientCursor, ClientCursor
+from .connection_async import AsyncConnection
+
+from . import dbapi20
+from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
+from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
+from .dbapi20 import Timestamp, TimestampFromTicks
+
+from .version import __version__ as __version__ # noqa: F401
+
+# Set the logger to a quiet default, can be enabled if needed
+logger = logging.getLogger("psycopg")
+if logger.level == logging.NOTSET:
+ logger.setLevel(logging.WARNING)
+
+# DBAPI compliance
+connect = Connection.connect
+apilevel = "2.0"
+threadsafety = 2
+paramstyle = "pyformat"
+
+# register default adapters for PostgreSQL
+adapters = postgres.adapters # exposed by the package
+postgres.register_default_adapters(adapters)
+
+# After the default ones, because these can deal with the bytea oid better
+dbapi20.register_dbapi20_adapters(adapters)
+
+# Must come after all the types have been registered
+types.array.register_all_arrays(adapters)
+
+# Note: defining the exported methods helps both Sphynx in documenting that
+# this is the canonical place to obtain them and should be used by MyPy too,
+# so that function signatures are consistent with the documentation.
+__all__ = [
+ "AsyncClientCursor",
+ "AsyncConnection",
+ "AsyncCopy",
+ "AsyncCursor",
+ "AsyncPipeline",
+ "AsyncServerCursor",
+ "AsyncTransaction",
+ "BaseConnection",
+ "ClientCursor",
+ "Column",
+ "Connection",
+ "ConnectionInfo",
+ "Copy",
+ "Cursor",
+ "IsolationLevel",
+ "Notify",
+ "Pipeline",
+ "Rollback",
+ "ServerCursor",
+ "Transaction",
+ "Xid",
+ # DBAPI exports
+ "connect",
+ "apilevel",
+ "threadsafety",
+ "paramstyle",
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ # DBAPI type constructors and singletons
+ "Binary",
+ "Date",
+ "DateFromTicks",
+ "Time",
+ "TimeFromTicks",
+ "Timestamp",
+ "TimestampFromTicks",
+ "BINARY",
+ "DATETIME",
+ "NUMBER",
+ "ROWID",
+ "STRING",
+]
diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py
new file mode 100644
index 0000000..a3a6ef8
--- /dev/null
+++ b/psycopg/psycopg/_adapters_map.py
@@ -0,0 +1,289 @@
+"""
+Mapping from types/oids to Dumpers/Loaders
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+from typing import cast, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from .abc import Dumper, Loader
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+from ._typeinfo import TypesRegistry
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+RV = TypeVar("RV")
+
+
+class AdaptersMap:
+ r"""
+ Establish how types should be converted between Python and PostgreSQL in
+ an `~psycopg.abc.AdaptContext`.
+
+ `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
+ define how Python types are converted to PostgreSQL, and maps OIDs to
+ `~psycopg.adapt.Loader` classes to establish how query results are
+ converted to Python.
+
+ Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
+ types are converted in that context, exposed as the
+ `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows
+ to customise adaptation in a context without changing separated contexts.
+
+ When a context is created from another context (for instance when a
+ `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
+ `!adapters` are used as template for the child's `!adapters`, so that every
+ cursor created from the same connection use the connection's types
+ configuration, but separate connections have independent mappings.
+
+ Once created, `!AdaptersMap` are independent. This means that objects
+ already created are not affected if a wider scope (e.g. the global one) is
+ changed.
+
+ The connections adapters are initialised using a global `!AdptersMap`
+ template, exposed as `psycopg.adapters`: changing such mapping allows to
+ customise the type mapping for every connections created afterwards.
+
+ The object can start empty or copy from another object of the same class.
+ Copies are copy-on-write: if the maps are updated make a copy. This way
+ extending e.g. global map by a connection or a connection map from a cursor
+ is cheap: a copy is only made on customisation.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ types: TypesRegistry
+
+ _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
+ _dumpers_by_oid: List[Dict[int, Type[Dumper]]]
+ _loaders: List[Dict[int, Type[Loader]]]
+
+ # Record if a dumper or loader has an optimised version.
+ _optimised: Dict[type, type] = {}
+
+ def __init__(
+ self,
+ template: Optional["AdaptersMap"] = None,
+ types: Optional[TypesRegistry] = None,
+ ):
+ if template:
+ self._dumpers = template._dumpers.copy()
+ self._own_dumpers = _dumpers_shared.copy()
+ template._own_dumpers = _dumpers_shared.copy()
+
+ self._dumpers_by_oid = template._dumpers_by_oid[:]
+ self._own_dumpers_by_oid = [False, False]
+ template._own_dumpers_by_oid = [False, False]
+
+ self._loaders = template._loaders[:]
+ self._own_loaders = [False, False]
+ template._own_loaders = [False, False]
+
+ self.types = TypesRegistry(template.types)
+
+ else:
+ self._dumpers = {fmt: {} for fmt in PyFormat}
+ self._own_dumpers = _dumpers_owned.copy()
+
+ self._dumpers_by_oid = [{}, {}]
+ self._own_dumpers_by_oid = [True, True]
+
+ self._loaders = [{}, {}]
+ self._own_loaders = [True, True]
+
+ self.types = types or TypesRegistry()
+
+ # implement the AdaptContext protocol too
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return None
+
+ def register_dumper(
+ self, cls: Union[type, str, None], dumper: Type[Dumper]
+ ) -> None:
+ """
+ Configure the context to use `!dumper` to convert objects of type `!cls`.
+
+ If two dumpers with different `~Dumper.format` are registered for the
+ same type, the last one registered will be chosen when the query
+ doesn't specify a format (i.e. when the value is used with a ``%s``
+ "`~PyFormat.AUTO`" placeholder).
+
+ :param cls: The type to manage.
+ :param dumper: The dumper to register for `!cls`.
+
+ If `!cls` is specified as string it will be lazy-loaded, so that it
+ will be possible to register it without importing it before. In this
+ case it should be the fully qualified name of the object (e.g.
+ ``"uuid.UUID"``).
+
+ If `!cls` is None, only use the dumper when looking up using
+ `get_dumper_by_oid()`, which happens when we know the Postgres type to
+ adapt to, but not the Python type that will be adapted (e.g. in COPY
+ after using `~psycopg.Copy.set_types()`).
+
+ """
+ if not (cls is None or isinstance(cls, (str, type))):
+ raise TypeError(
+ f"dumpers should be registered on classes, got {cls} instead"
+ )
+
+ if _psycopg:
+ dumper = self._get_optimised(dumper)
+
+ # Register the dumper both as its format and as auto
+ # so that the last dumper registered is used in auto (%s) format
+ if cls:
+ for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
+ if not self._own_dumpers[fmt]:
+ self._dumpers[fmt] = self._dumpers[fmt].copy()
+ self._own_dumpers[fmt] = True
+
+ self._dumpers[fmt][cls] = dumper
+
+ # Register the dumper by oid, if the oid of the dumper is fixed
+ if dumper.oid:
+ if not self._own_dumpers_by_oid[dumper.format]:
+ self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
+ dumper.format
+ ].copy()
+ self._own_dumpers_by_oid[dumper.format] = True
+
+ self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
+
+ def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
+ """
+ Configure the context to use `!loader` to convert data of oid `!oid`.
+
+ :param oid: The PostgreSQL OID or type name to manage.
+ :param loader: The loar to register for `!oid`.
+
+ If `oid` is specified as string, it refers to a type name, which is
+ looked up in the `types` registry. `
+
+ """
+ if isinstance(oid, str):
+ oid = self.types[oid].oid
+ if not isinstance(oid, int):
+ raise TypeError(f"loaders should be registered on oid, got {oid} instead")
+
+ if _psycopg:
+ loader = self._get_optimised(loader)
+
+ fmt = loader.format
+ if not self._own_loaders[fmt]:
+ self._loaders[fmt] = self._loaders[fmt].copy()
+ self._own_loaders[fmt] = True
+
+ self._loaders[fmt][oid] = loader
+
+ def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given type and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param cls: The class to adapt.
+ :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
+ use the last one of the dumpers registered on `!cls`.
+ """
+ try:
+ dmap = self._dumpers[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ # Look for the right class, including looking at superclasses
+ for scls in cls.__mro__:
+ if scls in dmap:
+ return dmap[scls]
+
+ # If the adapter is not found, look for its name as a string
+ fqn = scls.__module__ + "." + scls.__qualname__
+ if fqn in dmap:
+ # Replace the class name with the class itself
+ d = dmap[scls] = dmap.pop(fqn)
+ return d
+
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'"
+ f" (format: {PyFormat(format).name})"
+ )
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given oid and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param oid: The oid of the type to dump to.
+ :param format: The format to dump to.
+ """
+ try:
+ dmap = self._dumpers_by_oid[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ try:
+ return dmap[oid]
+ except KeyError:
+ info = self.types.get(oid)
+ if info:
+ msg = (
+ f"cannot find a dumper for type {info.name} (oid {oid})"
+ f" format {pq.Format(format).name}"
+ )
+ else:
+ msg = (
+ f"cannot find a dumper for unknown type with oid {oid}"
+ f" format {pq.Format(format).name}"
+ )
+ raise e.ProgrammingError(msg)
+
+ def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
+ """
+ Return the loader class for the given oid and format.
+
+ Return `!None` if not found.
+
+ :param oid: The oid of the type to load.
+ :param format: The format to load from.
+ """
+ return self._loaders[format].get(oid)
+
+ @classmethod
+ def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
+ """Return the optimised version of a Dumper or Loader class.
+
+ Return the input class itself if there is no optimised version.
+ """
+ try:
+ return self._optimised[cls]
+ except KeyError:
+ pass
+
+ # Check if the class comes from psycopg.types and there is a class
+ # with the same name in psycopg_c._psycopg.
+ from psycopg import types
+
+ if cls.__module__.startswith(types.__name__):
+ new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
+ if new:
+ self._optimised[cls] = new
+ return new
+
+ self._optimised[cls] = cls
+ return cls
+
+
+# Micro-optimization: copying these objects is faster than creating new dicts
+_dumpers_owned = dict.fromkeys(PyFormat, True)
+_dumpers_shared = dict.fromkeys(PyFormat, False)
diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py
new file mode 100644
index 0000000..288ef1b
--- /dev/null
+++ b/psycopg/psycopg/_cmodule.py
@@ -0,0 +1,24 @@
+"""
+Simplify access to the _psycopg module
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from typing import Optional
+
+from . import pq
+
+__version__: Optional[str] = None
+
+# Note: "c" must the first attempt so that mypy associates the variable the
+# right module interface. It will not result Optional, but hey.
+if pq.__impl__ == "c":
+ from psycopg_c import _psycopg as _psycopg
+ from psycopg_c import __version__ as __version__ # noqa: F401
+elif pq.__impl__ == "binary":
+ from psycopg_binary import _psycopg as _psycopg # type: ignore
+ from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
+elif pq.__impl__ == "python":
+ _psycopg = None # type: ignore
+else:
+ raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py
new file mode 100644
index 0000000..9e4e735
--- /dev/null
+++ b/psycopg/psycopg/_column.py
@@ -0,0 +1,143 @@
+"""
+The Column object in Cursor.description
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
+from operator import attrgetter
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor
+
+
+class ColumnData(NamedTuple):
+ ftype: int
+ fmod: int
+ fsize: int
+
+
+class Column(Sequence[Any]):
+
+ __module__ = "psycopg"
+
+ def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
+ res = cursor.pgresult
+ assert res
+
+ fname = res.fname(index)
+ if fname:
+ self._name = fname.decode(cursor._encoding)
+ else:
+ # COPY_OUT results have columns but no name
+ self._name = f"column_{index + 1}"
+
+ self._data = ColumnData(
+ ftype=res.ftype(index),
+ fmod=res.fmod(index),
+ fsize=res.fsize(index),
+ )
+ self._type = cursor.adapters.types.get(self._data.ftype)
+
+ _attrs = tuple(
+ attrgetter(attr)
+ for attr in """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<Column {self.name!r},"
+ f" type: {self._type_display()} (oid: {self.type_code})>"
+ )
+
+ def __len__(self) -> int:
+ return 7
+
+ def _type_display(self) -> str:
+ parts = []
+ parts.append(self._type.name if self._type else str(self.type_code))
+
+ mod1 = self.precision
+ if mod1 is None:
+ mod1 = self.display_size
+ if mod1:
+ parts.append(f"({mod1}")
+ if self.scale:
+ parts.append(f", {self.scale}")
+ parts.append(")")
+
+ if self._type and self.type_code == self._type.array_oid:
+ parts.append("[]")
+
+ return "".join(parts)
+
+ def __getitem__(self, index: Any) -> Any:
+ if isinstance(index, slice):
+ return tuple(getter(self) for getter in self._attrs[index])
+ else:
+ return self._attrs[index](self)
+
+ @property
+ def name(self) -> str:
+ """The name of the column."""
+ return self._name
+
+ @property
+ def type_code(self) -> int:
+ """The numeric OID of the column."""
+ return self._data.ftype
+
+ @property
+ def display_size(self) -> Optional[int]:
+ """The field size, for :sql:`varchar(n)`, None otherwise."""
+ if not self._type:
+ return None
+
+ if self._type.name in ("varchar", "char"):
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod - 4
+
+ return None
+
+ @property
+ def internal_size(self) -> Optional[int]:
+ """The internal field size for fixed-size types, None otherwise."""
+ fsize = self._data.fsize
+ return fsize if fsize >= 0 else None
+
+ @property
+ def precision(self) -> Optional[int]:
+ """The number of digits for fixed precision types."""
+ if not self._type:
+ return None
+
+ dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
+ if self._type.name == "numeric":
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod >> 16
+
+ elif self._type.name in dttypes:
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def scale(self) -> Optional[int]:
+ """The number of digits after the decimal point if available."""
+ if self._type and self._type.name == "numeric":
+ fmod = self._data.fmod - 4
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def null_ok(self) -> Optional[bool]:
+ """Always `!None`"""
+ return None
diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py
new file mode 100644
index 0000000..7dbae79
--- /dev/null
+++ b/psycopg/psycopg/_compat.py
@@ -0,0 +1,72 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
+
+# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
+# For this raisin it must be imported directly from typing_extension where used.
+# See https://github.com/microsoft/pyright/issues/4197
+from typing_extensions import TypeAlias
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+T = TypeVar("T")
+FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 8):
+ create_task = asyncio.create_task
+ from math import prod
+
+else:
+
+ def create_task(
+ coro: FutureT[T], name: Optional[str] = None
+ ) -> "asyncio.Future[T]":
+ return asyncio.create_task(coro)
+
+ from functools import reduce
+
+ def prod(seq: Sequence[int]) -> int:
+ return reduce(int.__mul__, seq, 1)
+
+
+if sys.version_info >= (3, 9):
+ from zoneinfo import ZoneInfo
+ from functools import cache
+ from collections import Counter, deque as Deque
+else:
+ from typing import Counter, Deque
+ from functools import lru_cache
+ from backports.zoneinfo import ZoneInfo
+
+ cache = lru_cache(maxsize=None)
+
+if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+else:
+ from typing_extensions import TypeGuard
+
+if sys.version_info >= (3, 11):
+ from typing import LiteralString
+else:
+ from typing_extensions import LiteralString
+
+__all__ = [
+ "Counter",
+ "Deque",
+ "LiteralString",
+ "Protocol",
+ "TypeGuard",
+ "ZoneInfo",
+ "cache",
+ "create_task",
+ "prod",
+]
diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
new file mode 100644
index 0000000..1e146ba
--- /dev/null
+++ b/psycopg/psycopg/_dns.py
@@ -0,0 +1,223 @@
+# type: ignore # dnspython is currently optional and mypy fails if missing
+"""
+DNS query support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import os
+import re
+import warnings
+from random import randint
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
+from typing import TYPE_CHECKING
+from collections import defaultdict
+
+try:
+ from dns.resolver import Resolver, Cache
+ from dns.asyncresolver import Resolver as AsyncResolver
+ from dns.exception import DNSException
+except ImportError:
+ raise ImportError(
+ "the module psycopg._dns requires the package 'dnspython' installed"
+ )
+
+from . import errors as e
+from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_
+
+if TYPE_CHECKING:
+ from dns.rdtypes.IN.SRV import SRV
+
+resolver = Resolver()
+resolver.cache = Cache()
+
+async_resolver = AsyncResolver()
+async_resolver.cache = Cache()
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ .. deprecated:: 3.1
+ The use of this function is not necessary anymore, because
+ `psycopg.AsyncConnection.connect()` performs non-blocking name
+ resolution automatically.
+ """
+ warnings.warn(
+ "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
+ DeprecationWarning,
+ )
+ return await resolve_hostaddr_async_(params)
+
+
+def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply SRV DNS lookup as defined in :RFC:`2782`."""
+ return Rfc2782Resolver().resolve(params)
+
+
+async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Async equivalent of `resolve_srv()`."""
+ return await Rfc2782Resolver().resolve_async(params)
+
+
+class HostPort(NamedTuple):
+ host: str
+ port: str
+ totry: bool = False
+ target: Optional[str] = None
+
+
+class Rfc2782Resolver:
+ """Implement SRV RR Resolution as per RFC 2782
+
+ The class is organised to minimise code duplication between the sync and
+ the async paths.
+ """
+
+ re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
+
+ def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(self._resolve_srv(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(await self._resolve_srv_async(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
+ """
+ Return the list of host, and for each host if SRV lookup must be tried.
+
+ Return an empty list if no lookup is requested.
+ """
+ # If hostaddr is defined don't do any resolution.
+ if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
+ return []
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",")
+
+ if len(ports_in) == 1:
+ # If only one port is specified, it applies to all the hosts.
+ ports_in *= len(hosts_in)
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+
+ out = []
+ srv_found = False
+ for host, port in zip(hosts_in, ports_in):
+ m = self.re_srv_rr.match(host)
+ if m or port.lower() == "srv":
+ srv_found = True
+ target = m.group("target") if m else None
+ hp = HostPort(host=host, port=port, totry=True, target=target)
+ else:
+ hp = HostPort(host=host, port=port)
+ out.append(hp)
+
+ return out if srv_found else []
+
+ def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = await async_resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ def _get_solved_entries(
+ self, hp: HostPort, entries: "Sequence[SRV]"
+ ) -> List[HostPort]:
+ if not entries:
+ # No SRV entry found. Delegate the libpq a QNAME=target lookup
+ if hp.target and hp.port.lower() != "srv":
+ return [HostPort(host=hp.target, port=hp.port)]
+ else:
+ return []
+
+ # If there is precisely one SRV RR, and its Target is "." (the root
+ # domain), abort.
+ if len(entries) == 1 and str(entries[0].target) == ".":
+ return []
+
+ return [
+ HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
+ for entry in self.sort_rfc2782(entries)
+ ]
+
+ def _return_params(
+ self, params: Dict[str, Any], hps: List[HostPort]
+ ) -> Dict[str, Any]:
+ if not hps:
+ # Nothing found, we ended up with an empty list
+ raise e.OperationalError("no host found after SRV RR lookup")
+
+ out = params.copy()
+ out["host"] = ",".join(hp.host for hp in hps)
+ out["port"] = ",".join(str(hp.port) for hp in hps)
+ return out
+
+ def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
+ """
+ Implement the priority/weight ordering defined in RFC 2782.
+ """
+ # Divide the entries by priority:
+ priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
+ out: "List[SRV]" = []
+ for entry in ans:
+ priorities[entry.priority].append(entry)
+
+ for pri, entries in sorted(priorities.items()):
+ if len(entries) == 1:
+ out.append(entries[0])
+ continue
+
+ entries.sort(key=lambda ent: ent.weight)
+ total_weight = sum(ent.weight for ent in entries)
+ while entries:
+ r = randint(0, total_weight)
+ csum = 0
+ for i, ent in enumerate(entries):
+ csum += ent.weight
+ if csum >= r:
+ break
+ out.append(ent)
+ total_weight -= ent.weight
+ del entries[i]
+
+ return out
diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py
new file mode 100644
index 0000000..c584b26
--- /dev/null
+++ b/psycopg/psycopg/_encodings.py
@@ -0,0 +1,170 @@
+"""
+Mappings between PostgreSQL and Python encodings.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import string
+import codecs
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+from .pq._enums import ConnStatus
+from .errors import NotSupportedError
+from ._compat import cache
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+ from .connection import BaseConnection
+
+OK = ConnStatus.OK
+
+
+_py_codecs = {
+ "BIG5": "big5",
+ "EUC_CN": "gb2312",
+ "EUC_JIS_2004": "euc_jis_2004",
+ "EUC_JP": "euc_jp",
+ "EUC_KR": "euc_kr",
+ # "EUC_TW": not available in Python
+ "GB18030": "gb18030",
+ "GBK": "gbk",
+ "ISO_8859_5": "iso8859-5",
+ "ISO_8859_6": "iso8859-6",
+ "ISO_8859_7": "iso8859-7",
+ "ISO_8859_8": "iso8859-8",
+ "JOHAB": "johab",
+ "KOI8R": "koi8-r",
+ "KOI8U": "koi8-u",
+ "LATIN1": "iso8859-1",
+ "LATIN10": "iso8859-16",
+ "LATIN2": "iso8859-2",
+ "LATIN3": "iso8859-3",
+ "LATIN4": "iso8859-4",
+ "LATIN5": "iso8859-9",
+ "LATIN6": "iso8859-10",
+ "LATIN7": "iso8859-13",
+ "LATIN8": "iso8859-14",
+ "LATIN9": "iso8859-15",
+ # "MULE_INTERNAL": not available in Python
+ "SHIFT_JIS_2004": "shift_jis_2004",
+ "SJIS": "shift_jis",
+ # this actually means no encoding, see PostgreSQL docs
+ # it is special-cased by the text loader.
+ "SQL_ASCII": "ascii",
+ "UHC": "cp949",
+ "UTF8": "utf-8",
+ "WIN1250": "cp1250",
+ "WIN1251": "cp1251",
+ "WIN1252": "cp1252",
+ "WIN1253": "cp1253",
+ "WIN1254": "cp1254",
+ "WIN1255": "cp1255",
+ "WIN1256": "cp1256",
+ "WIN1257": "cp1257",
+ "WIN1258": "cp1258",
+ "WIN866": "cp866",
+ "WIN874": "cp874",
+}
+
+py_codecs: Dict[bytes, str] = {}
+py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
+
+# Add an alias without underscore, for lenient lookups
+py_codecs.update(
+ (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
+)
+
+pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
+
+
+def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
+ """
+ Return the Python encoding name of a psycopg connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if not conn or conn.closed:
+ return "utf-8"
+
+ pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def pgconn_encoding(pgconn: "PGconn") -> str:
+ """
+ Return the Python encoding name of a libpq connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if pgconn.status != OK:
+ return "utf-8"
+
+ pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def conninfo_encoding(conninfo: str) -> str:
+ """
+ Return the Python encoding name passed in a conninfo string. Default to utf8.
+
+ Because the input is likely to come from the user and not normalised by the
+ server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars).
+ """
+ from .conninfo import conninfo_to_dict
+
+ params = conninfo_to_dict(conninfo)
+ pgenc = params.get("client_encoding")
+ if pgenc:
+ try:
+ return pg2pyenc(pgenc.encode())
+ except NotSupportedError:
+ pass
+
+ return "utf-8"
+
+
+@cache
+def py2pgenc(name: str) -> bytes:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise LookupError if the Python encoding is unknown.
+ """
+ return pg_codecs[codecs.lookup(name).name]
+
+
+@cache
+def pg2pyenc(name: bytes) -> str:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise NotSupportedError if the PostgreSQL encoding is not supported by
+ Python.
+ """
+ try:
+ return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
+ except KeyError:
+ sname = name.decode("utf8", "replace")
+ raise NotSupportedError(f"codec not available in Python: {sname!r}")
+
+
+def _as_python_identifier(s: str, prefix: str = "f") -> str:
+ """
+ Reduce a string to a valid Python identifier.
+
+ Replace all non-valid chars with '_' and prefix the value with `!prefix` if
+ the first letter is an '_'.
+ """
+ if not s.isidentifier():
+ if s[0] in "1234567890":
+ s = prefix + s
+ if not s.isidentifier():
+ s = _re_clean.sub("_", s)
+ # namedtuple fields cannot start with underscore. So...
+ if s[0] == "_":
+ s = prefix + s
+ return s
+
+
+_re_clean = re.compile(
+ f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
+)
diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py
new file mode 100644
index 0000000..a7cb78d
--- /dev/null
+++ b/psycopg/psycopg/_enums.py
@@ -0,0 +1,79 @@
+"""
+Enum values for psycopg
+
+These values are defined by us and are not necessarily dependent on
+libpq-defined enums.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import Enum, IntEnum
+from selectors import EVENT_READ, EVENT_WRITE
+
+from . import pq
+
+
+class Wait(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class Ready(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class PyFormat(str, Enum):
+ """
+ Enum representing the format wanted for a query argument.
+
+ The value `AUTO` allows psycopg to choose the best format for a certain
+ parameter.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ AUTO = "s"
+ """Automatically chosen (``%s`` placeholder)."""
+ TEXT = "t"
+ """Text parameter (``%t`` placeholder)."""
+ BINARY = "b"
+ """Binary parameter (``%b`` placeholder)."""
+
+ @classmethod
+ def from_pq(cls, fmt: pq.Format) -> "PyFormat":
+ return _pg2py[fmt]
+
+ @classmethod
+ def as_pq(cls, fmt: "PyFormat") -> pq.Format:
+ return _py2pg[fmt]
+
+
+class IsolationLevel(IntEnum):
+ """
+ Enum representing the isolation level for a transaction.
+ """
+
+ __module__ = "psycopg"
+
+ READ_UNCOMMITTED = 1
+ """:sql:`READ UNCOMMITTED` isolation level."""
+ READ_COMMITTED = 2
+ """:sql:`READ COMMITTED` isolation level."""
+ REPEATABLE_READ = 3
+ """:sql:`REPEATABLE READ` isolation level."""
+ SERIALIZABLE = 4
+ """:sql:`SERIALIZABLE` isolation level."""
+
+
+_py2pg = {
+ PyFormat.TEXT: pq.Format.TEXT,
+ PyFormat.BINARY: pq.Format.BINARY,
+}
+
+_pg2py = {
+ pq.Format.TEXT: PyFormat.TEXT,
+ pq.Format.BINARY: PyFormat.BINARY,
+}
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
new file mode 100644
index 0000000..c818d86
--- /dev/null
+++ b/psycopg/psycopg/_pipeline.py
@@ -0,0 +1,288 @@
+"""
+commands pipeline management
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+from types import TracebackType
+from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from .abc import PipelineCommand, PQGen
+from ._compat import Deque
+from ._encodings import pgconn_encoding
+from ._preparing import Key, Prepare
+from .generators import pipeline_communicate, fetch_many, send
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+ from .cursor import BaseCursor
+ from .connection import BaseConnection, Connection
+ from .connection_async import AsyncConnection
+
+
+PendingResult: TypeAlias = Union[
+ None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
+]
+
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+BAD = pq.ConnStatus.BAD
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+logger = logging.getLogger("psycopg")
+
+
+class BasePipeline:
+
+ command_queue: Deque[PipelineCommand]
+ result_queue: Deque[PendingResult]
+ _is_supported: Optional[bool] = None
+
+ def __init__(self, conn: "BaseConnection[Any]") -> None:
+ self._conn = conn
+ self.pgconn = conn.pgconn
+ self.command_queue = Deque[PipelineCommand]()
+ self.result_queue = Deque[PendingResult]()
+ self.level = 0
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._conn.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def status(self) -> pq.PipelineStatus:
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ @classmethod
+ def is_supported(cls) -> bool:
+ """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
+ if BasePipeline._is_supported is None:
+ BasePipeline._is_supported = not cls._not_supported_reason()
+ return BasePipeline._is_supported
+
+ @classmethod
+ def _not_supported_reason(cls) -> str:
+ """Return the reason why the pipeline mode is not supported.
+
+ Return an empty string if pipeline mode is supported.
+ """
+ # Support only depends on the libpq functions available in the pq
+ # wrapper, not on the database version.
+ if pq.version() < 140000:
+ return (
+ f"libpq too old {pq.version()};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ if pq.__build_version__ < 140000:
+ return (
+ f"libpq too old: module built for {pq.__build_version__};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ return ""
+
+ def _enter_gen(self) -> PQGen[None]:
+ if not self.is_supported():
+ raise e.NotSupportedError(
+ f"pipeline mode not supported: {self._not_supported_reason()}"
+ )
+ if self.level == 0:
+ self.pgconn.enter_pipeline_mode()
+ elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
+ # Nested pipeline case.
+ # Transaction might be ACTIVE when the pipeline uses an "implicit
+ # transaction", typically in autocommit mode. But when entering a
+ # Psycopg transaction(), we expect the IDLE state. By sync()-ing,
+ # we make sure all previous commands are completed and the
+ # transaction gets back to IDLE.
+ yield from self._sync_gen()
+ self.level += 1
+
+ def _exit(self, exc: Optional[BaseException]) -> None:
+ self.level -= 1
+ if self.level == 0 and self.pgconn.status != BAD:
+ try:
+ self.pgconn.exit_pipeline_mode()
+ except e.OperationalError as exc2:
+ # Notice that this error might be pretty irrecoverable. It
+ # happens on COPY, for instance: even if sync succeeds, exiting
+ # fails with "cannot exit pipeline mode with uncollected results"
+ if exc:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+
+ def _sync_gen(self) -> PQGen[None]:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ yield from self._fetch_gen(flush=False)
+
+ def _exit_gen(self) -> PQGen[None]:
+ """
+ Exit current pipeline by sending a Sync and fetch back all remaining results.
+ """
+ try:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ finally:
+ # No need to force flush since we emitted a sync just before.
+ yield from self._fetch_gen(flush=False)
+
+ def _communicate_gen(self) -> PQGen[None]:
+ """Communicate with pipeline to send commands and possibly fetch
+ results, which are then processed.
+ """
+ fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+ to_process = [(self.result_queue.popleft(), results) for results in fetched]
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+ """Fetch available results from the connection and process them with
+ pipeline queued items.
+
+ If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+ sure results can be fetched. Otherwise, the caller may emit a
+ PQpipelineSync() call to ensure the output buffer gets flushed before
+ fetching.
+ """
+ if not self.result_queue:
+ return
+
+ if flush:
+ self.pgconn.send_flush_request()
+ yield from send(self.pgconn)
+
+ to_process = []
+ while self.result_queue:
+ results = yield from fetch_many(self.pgconn)
+ if not results:
+ # No more results to fetch, but there may still be pending
+ # commands.
+ break
+ queued = self.result_queue.popleft()
+ to_process.append((queued, results))
+
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _process_results(
+ self, queued: PendingResult, results: List["PGresult"]
+ ) -> None:
+ """Process a results set fetched from the current pipeline.
+
+ This matches 'results' with its respective element in the pipeline
+ queue. For commands (None value in the pipeline queue), results are
+ checked directly. For prepare statement creation requests, update the
+ cache. Otherwise, results are attached to their respective cursor.
+ """
+ if queued is None:
+ (result,) = results
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ elif result.status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ cursor, prepinfo = queued
+ cursor._set_results_from_pipeline(results)
+ if prepinfo:
+ key, prep, name = prepinfo
+ # Update the prepare state of the query.
+ cursor._conn._prepared.validate(key, prep, name, results)
+
+ def _enqueue_sync(self) -> None:
+ """Enqueue a PQpipelineSync() command."""
+ self.command_queue.append(self.pgconn.pipeline_sync)
+ self.result_queue.append(None)
+
+
+class Pipeline(BasePipeline):
+ """Handler for connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "Connection[Any]"
+ _Self = TypeVar("_Self", bound="Pipeline")
+
+ def __init__(self, conn: "Connection[Any]") -> None:
+ super().__init__(conn)
+
+ def sync(self) -> None:
+ """Sync the pipeline, send any pending command and receive and process
+ all available results.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
+
+
+class AsyncPipeline(BasePipeline):
+ """Handler for async connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "AsyncConnection[Any]"
+ _Self = TypeVar("_Self", bound="AsyncPipeline")
+
+ def __init__(self, conn: "AsyncConnection[Any]") -> None:
+ super().__init__(conn)
+
+ async def sync(self) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py
new file mode 100644
index 0000000..f60c0cb
--- /dev/null
+++ b/psycopg/psycopg/_preparing.py
@@ -0,0 +1,198 @@
+"""
+Support for prepared statements
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, auto
+from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
+from collections import OrderedDict
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._compat import Deque
+from ._queries import PostgresQuery
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+
+Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+
+class Prepare(IntEnum):
+ NO = auto()
+ YES = auto()
+ SHOULD = auto()
+
+
+class PrepareManager:
+ # Number of times a query is executed before it is prepared.
+ prepare_threshold: Optional[int] = 5
+
+ # Maximum number of prepared statements on the connection.
+ prepared_max: int = 100
+
+ def __init__(self) -> None:
+ # Map (query, types) to the number of times the query was seen.
+ self._counts: OrderedDict[Key, int] = OrderedDict()
+
+ # Map (query, types) to the name of the statement if prepared.
+ self._names: OrderedDict[Key, bytes] = OrderedDict()
+
+ # Counter to generate prepared statements names
+ self._prepared_idx = 0
+
+ self._maint_commands = Deque[bytes]()
+
+ @staticmethod
+ def key(query: PostgresQuery) -> Key:
+ return (query.query, query.types)
+
+ def get(
+ self, query: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ """
+ Check if a query is prepared, tell back whether to prepare it.
+ """
+ if prepare is False or self.prepare_threshold is None:
+ # The user doesn't want this query to be prepared
+ return Prepare.NO, b""
+
+ key = self.key(query)
+ name = self._names.get(key)
+ if name:
+ # The query was already prepared in this session
+ return Prepare.YES, name
+
+ count = self._counts.get(key, 0)
+ if count >= self.prepare_threshold or prepare:
+ # The query has been executed enough times and needs to be prepared
+ name = f"_pg3_{self._prepared_idx}".encode()
+ self._prepared_idx += 1
+ return Prepare.SHOULD, name
+ else:
+ # The query is not to be prepared yet
+ return Prepare.NO, b""
+
+ def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
+ """Check if we need to discard our entire state: it should happen on
+ rollback or on dropping objects, because the same object may get
+ recreated and postgres would fail internal lookups.
+ """
+ if self._names or prep == Prepare.SHOULD:
+ for result in results:
+ if result.status != COMMAND_OK:
+ continue
+ cmdstat = result.command_status
+ if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
+ return self.clear()
+ return False
+
+ @staticmethod
+ def _check_results(results: Sequence["PGresult"]) -> bool:
+ """Return False if 'results' are invalid for prepared statement cache."""
+ if len(results) != 1:
+ # We cannot prepare a multiple statement
+ return False
+
+ status = results[0].status
+ if COMMAND_OK != status != TUPLES_OK:
+ # We don't prepare failed queries or other weird results
+ return False
+
+ return True
+
+ def _rotate(self) -> None:
+ """Evict an old value from the cache.
+
+ If it was prepared, deallocate it. Do it only once: if the cache was
+ resized, deallocate gradually.
+ """
+ if len(self._counts) > self.prepared_max:
+ self._counts.popitem(last=False)
+
+ if len(self._names) > self.prepared_max:
+ name = self._names.popitem(last=False)[1]
+ self._maint_commands.append(b"DEALLOCATE " + name)
+
+ def maybe_add_to_cache(
+ self, query: PostgresQuery, prep: Prepare, name: bytes
+ ) -> Optional[Key]:
+ """Handle 'query' for possible addition to the cache.
+
+ If a new entry has been added, return its key. Return None otherwise
+ (meaning the query is already in cache or cache is not enabled).
+
+ Note: This method is only called in pipeline mode.
+ """
+ # don't do anything if prepared statements are disabled
+ if self.prepare_threshold is None:
+ return None
+
+ key = self.key(query)
+ if key in self._counts:
+ if prep is Prepare.SHOULD:
+ del self._counts[key]
+ self._names[key] = name
+ else:
+ self._counts[key] += 1
+ self._counts.move_to_end(key)
+ return None
+
+ elif key in self._names:
+ self._names.move_to_end(key)
+ return None
+
+ else:
+ if prep is Prepare.SHOULD:
+ self._names[key] = name
+ else:
+ self._counts[key] = 1
+ return key
+
+ def validate(
+ self,
+ key: Key,
+ prep: Prepare,
+ name: bytes,
+ results: Sequence["PGresult"],
+ ) -> None:
+ """Validate cached entry with 'key' by checking query 'results'.
+
+ Possibly return a command to perform maintenance on database side.
+
+ Note: this method is only called in pipeline mode.
+ """
+ if self._should_discard(prep, results):
+ return
+
+ if not self._check_results(results):
+ self._names.pop(key, None)
+ self._counts.pop(key, None)
+ else:
+ self._rotate()
+
+ def clear(self) -> bool:
+ """Clear the cache of the maintenance commands.
+
+ Clear the internal state and prepare a command to clear the state of
+ the server.
+ """
+ self._counts.clear()
+ if self._names:
+ self._names.clear()
+ self._maint_commands.clear()
+ self._maint_commands.append(b"DEALLOCATE ALL")
+ return True
+ else:
+ return False
+
+ def get_maintenance_commands(self) -> Iterator[bytes]:
+ """
+ Iterate over the commands needed to align the server state to our state
+ """
+ while self._maint_commands:
+ yield self._maint_commands.popleft()
diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py
new file mode 100644
index 0000000..2a7554c
--- /dev/null
+++ b/psycopg/psycopg/_queries.py
@@ -0,0 +1,375 @@
+"""
+Utility module to manipulate queries
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
+from typing import Sequence, Tuple, Union, TYPE_CHECKING
+from functools import lru_cache
+
+from . import pq
+from . import errors as e
+from .sql import Composable
+from .abc import Buffer, Query, Params
+from ._enums import PyFormat
+from ._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+
+
+class QueryPart(NamedTuple):
+ pre: bytes
+ item: Union[int, str]
+ format: PyFormat
+
+
+class PostgresQuery:
+ """
+ Helper to convert a Python query and parameters into Postgres format.
+ """
+
+ __slots__ = """
+ query params types formats
+ _tx _want_formats _parts _encoding _order
+ """.split()
+
+ def __init__(self, transformer: "Transformer"):
+ self._tx = transformer
+
+ self.params: Optional[Sequence[Optional[Buffer]]] = None
+ # these are tuples so they can be used as keys e.g. in prepared stmts
+ self.types: Tuple[int, ...] = ()
+
+ # The format requested by the user and the ones to really pass Postgres
+ self._want_formats: Optional[List[PyFormat]] = None
+ self.formats: Optional[Sequence[pq.Format]] = None
+
+ self._encoding = conn_encoding(transformer.connection)
+ self._parts: List[QueryPart]
+ self.query = b""
+ self._order: Optional[List[str]] = None
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (
+ self.query,
+ self._want_formats,
+ self._order,
+ self._parts,
+ ) = _query2pg(bquery, self._encoding)
+ else:
+ self.query = bquery
+ self._want_formats = self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ assert self._want_formats is not None
+ self.params = self._tx.dump_sequence(params, self._want_formats)
+ self.types = self._tx.types or ()
+ self.formats = self._tx.formats
+ else:
+ self.params = None
+ self.types = ()
+ self.formats = None
+
+
+class PostgresClientQuery(PostgresQuery):
+ """
+ PostgresQuery subclass merging query and arguments client-side.
+ """
+
+ __slots__ = ("template",)
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (self.template, self._order, self._parts) = _query2pg_client(
+ bquery, self._encoding
+ )
+ else:
+ self.query = bquery
+ self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ self.params = tuple(
+ self._tx.as_literal(p) if p is not None else b"NULL" for p in params
+ )
+ self.query = self.template % self.params
+ else:
+ self.params = None
+
+
+@lru_cache()
+def _query2pg(
+ query: bytes, encoding: str
+) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into something Postgres understands.
+
+ - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+ format (``$1``, ``$2``)
+ - placeholders can be %s, %t, or %b (auto, text or binary)
+ - return ``query`` (bytes), ``formats`` (list of formats) ``order``
+ (sequence of names used in the query, in the position they appear)
+ ``parts`` (splits of queries and placeholders).
+ """
+ parts = _split_query(query, encoding)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+ formats = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"$%d" % (part.item + 1))
+ formats.append(part.format)
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"$%d" % (len(seen) + 1)
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ formats.append(part.format)
+ else:
+ if seen[part.item][1] != part.format:
+ raise e.ProgrammingError(
+ f"placeholder '{part.item}' cannot have different formats"
+ )
+ chunks.append(seen[part.item][0])
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), formats, order, parts
+
+
+@lru_cache()
+def _query2pg_client(
+ query: bytes, encoding: str
+) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into a template to perform client-side binding
+ """
+ parts = _split_query(query, encoding, collapse_double_percent=False)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"%s")
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"%s"
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ else:
+ chunks.append(seen[part.item][0])
+ order.append(part.item)
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), order, parts
+
+
+def _validate_and_reorder_params(
+ parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+) -> Sequence[Any]:
+ """
+ Verify the compatibility between a query and a set of params.
+ """
+ # Try concrete types, then abstract types
+ t = type(vars)
+ if t is list or t is tuple:
+ sequence = True
+ elif t is dict:
+ sequence = False
+ elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+ sequence = True
+ elif isinstance(vars, Mapping):
+ sequence = False
+ else:
+ raise TypeError(
+ "query parameters should be a sequence or a mapping,"
+ f" got {type(vars).__name__}"
+ )
+
+ if sequence:
+ if len(vars) != len(parts) - 1:
+ raise e.ProgrammingError(
+ f"the query has {len(parts) - 1} placeholders but"
+ f" {len(vars)} parameters were passed"
+ )
+ if vars and not isinstance(parts[0].item, int):
+ raise TypeError("named placeholders require a mapping of parameters")
+ return vars # type: ignore[return-value]
+
+ else:
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+ raise TypeError(
+ "positional placeholders (%s) require a sequence of parameters"
+ )
+ try:
+ return [vars[item] for item in order or ()] # type: ignore[call-overload]
+ except KeyError:
+ raise e.ProgrammingError(
+ "query parameter missing:"
+ f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+ )
+
+
+_re_placeholder = re.compile(
+ rb"""(?x)
+ % # a literal %
+ (?:
+ (?:
+ \( ([^)]+) \) # or a name in (braces)
+ . # followed by a format
+ )
+ |
+ (?:.) # or any char, really
+ )
+ """
+)
+
+
+def _split_query(
+ query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
+) -> List[QueryPart]:
+ parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
+ cur = 0
+
+ # pairs [(fragment, match], with the last match None
+ m = None
+ for m in _re_placeholder.finditer(query):
+ pre = query[cur : m.span(0)[0]]
+ parts.append((pre, m))
+ cur = m.span(0)[1]
+ if m:
+ parts.append((query[cur:], None))
+ else:
+ parts.append((query, None))
+
+ rv = []
+
+ # drop the "%%", validate
+ i = 0
+ phtype = None
+ while i < len(parts):
+ pre, m = parts[i]
+ if m is None:
+ # last part
+ rv.append(QueryPart(pre, 0, PyFormat.AUTO))
+ break
+
+ ph = m.group(0)
+ if ph == b"%%":
+ # unescape '%%' to '%' if necessary, then merge the parts
+ if collapse_double_percent:
+ ph = b"%"
+ pre1, m1 = parts[i + 1]
+ parts[i + 1] = (pre + ph + pre1, m1)
+ del parts[i]
+ continue
+
+ if ph == b"%(":
+ raise e.ProgrammingError(
+ "incomplete placeholder:"
+ f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
+ )
+ elif ph == b"% ":
+ # explicit messasge for a typical error
+ raise e.ProgrammingError(
+ "incomplete placeholder: '%'; if you want to use '%' as an"
+ " operator you can double it up, i.e. use '%%'"
+ )
+ elif ph[-1:] not in b"sbt":
+ raise e.ProgrammingError(
+ "only '%s', '%b', '%t' are allowed as placeholders, got"
+ f" '{m.group(0).decode(encoding)}'"
+ )
+
+ # Index or name
+ item: Union[int, str]
+ item = m.group(1).decode(encoding) if m.group(1) else i
+
+ if not phtype:
+ phtype = type(item)
+ elif phtype is not type(item):
+ raise e.ProgrammingError(
+ "positional and named placeholders cannot be mixed"
+ )
+
+ format = _ph_to_fmt[ph[-1:]]
+ rv.append(QueryPart(pre, item, format))
+ i += 1
+
+ return rv
+
+
+_ph_to_fmt = {
+ b"s": PyFormat.AUTO,
+ b"t": PyFormat.TEXT,
+ b"b": PyFormat.BINARY,
+}
diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py
new file mode 100644
index 0000000..28a6084
--- /dev/null
+++ b/psycopg/psycopg/_struct.py
@@ -0,0 +1,57 @@
+"""
+Utility functions to deal with binary structs.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from typing import Callable, cast, Optional, Tuple
+from typing_extensions import TypeAlias
+
+from .abc import Buffer
+from . import errors as e
+from ._compat import Protocol
+
+PackInt: TypeAlias = Callable[[int], bytes]
+UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
+PackFloat: TypeAlias = Callable[[float], bytes]
+UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
+
+
+class UnpackLen(Protocol):
+ def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
+ ...
+
+
+pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
+pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
+pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
+
+unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
+unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
+unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
+unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
+unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
+unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
+unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
+
+_struct_len = struct.Struct("!i")
+pack_len = cast(Callable[[int], bytes], _struct_len.pack)
+unpack_len = cast(UnpackLen, _struct_len.unpack_from)
+
+
+def pack_float4_bug_304(x: float) -> bytes:
+ raise e.InterfaceError(
+ "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
+ " and psycopg-binary packages are not affected by this issue."
+ " See https://github.com/psycopg/psycopg/issues/304"
+ )
+
+
+# If issue #304 is detected, raise an error instead of dumping wrong data.
+if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
+ pack_float4 = pack_float4_bug_304
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
new file mode 100644
index 0000000..3528188
--- /dev/null
+++ b/psycopg/psycopg/_tpc.py
@@ -0,0 +1,116 @@
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+ """A two-phase commit transaction identifier.
+
+ The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
+ `bqual`).
+
+ """
+
+ format_id: Optional[int]
+ gtrid: str
+ bqual: Optional[str]
+ prepared: Optional[dt.datetime] = None
+ owner: Optional[str] = None
+ database: Optional[str] = None
+
+ @classmethod
+ def from_string(cls, s: str) -> "Xid":
+ """Try to parse an XA triple from the string.
+
+ This may fail for several reasons. In such case return an unparsed Xid.
+ """
+ try:
+ return cls._parse_string(s)
+ except Exception:
+ return Xid(None, s, None)
+
+ def __str__(self) -> str:
+ return self._as_tid()
+
+ def __len__(self) -> int:
+ return 3
+
+ def __getitem__(self, index: int) -> Union[int, str, None]:
+ return (self.format_id, self.gtrid, self.bqual)[index]
+
+ @classmethod
+ def _parse_string(cls, s: str) -> "Xid":
+ m = _re_xid.match(s)
+ if not m:
+ raise ValueError("bad Xid format")
+
+ format_id = int(m.group(1))
+ gtrid = b64decode(m.group(2)).decode()
+ bqual = b64decode(m.group(3)).decode()
+ return cls.from_parts(format_id, gtrid, bqual)
+
+ @classmethod
+ def from_parts(
+ cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+ ) -> "Xid":
+ if format_id is not None:
+ if bqual is None:
+ raise TypeError("if format_id is specified, bqual must be too")
+ if not 0 <= format_id < 0x80000000:
+ raise ValueError("format_id must be a non-negative 32-bit integer")
+ if len(bqual) > 64:
+ raise ValueError("bqual must be not longer than 64 chars")
+ if len(gtrid) > 64:
+ raise ValueError("gtrid must be not longer than 64 chars")
+
+ elif bqual is None:
+ raise TypeError("if format_id is None, bqual must be None too")
+
+ return Xid(format_id, gtrid, bqual)
+
+ def _as_tid(self) -> str:
+ """
+ Return the PostgreSQL transaction_id for this XA xid.
+
+ PostgreSQL wants just a string, while the DBAPI supports the XA
+ standard and thus a triple. We use the same conversion algorithm
+ implemented by JDBC in order to allow some form of interoperation.
+
+ see also: the pgjdbc implementation
+ http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+ postgresql/xa/RecoveredXid.java?rev=1.2
+ """
+ if self.format_id is None or self.bqual is None:
+ # Unparsed xid: return the gtrid.
+ return self.gtrid
+
+ # XA xid: mash together the components.
+ egtrid = b64encode(self.gtrid.encode()).decode()
+ ebqual = b64encode(self.bqual.encode()).decode()
+
+ return f"{self.format_id}_{egtrid}_{ebqual}"
+
+ @classmethod
+ def _get_recover_query(cls) -> str:
+ return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+ @classmethod
+ def _from_record(
+ cls, gid: str, prepared: dt.datetime, owner: str, database: str
+ ) -> "Xid":
+ xid = Xid.from_string(gid)
+ return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py
new file mode 100644
index 0000000..19bd6ae
--- /dev/null
+++ b/psycopg/psycopg/_transform.py
@@ -0,0 +1,350 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import DefaultDict, TYPE_CHECKING
+from collections import defaultdict
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import postgres
+from . import errors as e
+from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
+from .rows import Row, RowMaker
+from .postgres import INVALID_OID, TEXT_OID
+from ._encodings import pgconn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Dumper, Loader
+ from .adapt import AdaptersMap
+ from .pq.abc import PGresult
+ from .connection import BaseConnection
+
+DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
+OidDumperCache: TypeAlias = Dict[int, "Dumper"]
+LoaderCache: TypeAlias = Dict[int, "Loader"]
+
+TEXT = pq.Format.TEXT
+PY_TEXT = PyFormat.TEXT
+
+
+class Transformer(AdaptContext):
+ """
+ An object that can adapt efficiently between Python and PostgreSQL.
+
+ The life cycle of the object is the query, so it is assumed that attributes
+ such as the server version or the connection encoding will not change. The
+ object have its state so adapting several values of the same type can be
+ optimised.
+
+ """
+
+ __module__ = "psycopg.adapt"
+
+ __slots__ = """
+ types formats
+ _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
+ _oid_dumpers _oid_types _row_dumpers _row_loaders
+ """.split()
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ _adapters: "AdaptersMap"
+ _pgresult: Optional["PGresult"]
+ _none_oid: int
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ self._pgresult = self.types = self.formats = None
+
+ # WARNING: don't store context, or you'll create a loop with the Cursor
+ if context:
+ self._adapters = context.adapters
+ self._conn = context.connection
+ else:
+ self._adapters = postgres.adapters
+ self._conn = None
+
+ # mapping fmt, class -> Dumper instance
+ self._dumpers: DefaultDict[PyFormat, DumperCache]
+ self._dumpers = defaultdict(dict)
+
+ # mapping fmt, oid -> Dumper instance
+ # Not often used, so create it only if needed.
+ self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
+ self._oid_dumpers = None
+
+ # mapping fmt, oid -> Loader instance
+ self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
+
+ self._row_dumpers: Optional[List["Dumper"]] = None
+
+ # sequence of load functions from value to python
+ # the length of the result columns
+ self._row_loaders: List[LoadFunc] = []
+
+ # mapping oid -> type sql representation
+ self._oid_types: Dict[int, bytes] = {}
+
+ self._encoding = ""
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ """
+ Return a Transformer from an AdaptContext.
+
+ If the context is a Transformer instance, just return it.
+ """
+ if isinstance(context, Transformer):
+ return context
+ else:
+ return cls(context)
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return self._conn
+
+ @property
+ def encoding(self) -> str:
+ if not self._encoding:
+ conn = self.connection
+ self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+ return self._encoding
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self._adapters
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ return self._pgresult
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None,
+ ) -> None:
+ self._pgresult = result
+
+ if not result:
+ self._nfields = self._ntuples = 0
+ if set_loaders:
+ self._row_loaders = []
+ return
+
+ self._ntuples = result.ntuples
+ nf = self._nfields = result.nfields
+
+ if not set_loaders:
+ return
+
+ if not nf:
+ self._row_loaders = []
+ return
+
+ fmt: pq.Format
+ fmt = result.fformat(0) if format is None else format # type: ignore
+ self._row_loaders = [
+ self.get_loader(result.ftype(i), fmt).load for i in range(nf)
+ ]
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
+ self.types = tuple(types)
+ self.formats = [format] * len(types)
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_loaders = [self.get_loader(oid, format).load for oid in types]
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ nparams = len(params)
+ out: List[Optional[Buffer]] = [None] * nparams
+
+ # If we have dumpers, it means set_dumper_types had been called, in
+ # which case self.types and self.formats are set to sequences of the
+ # right size.
+ if self._row_dumpers:
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ out[i] = self._row_dumpers[i].dump(param)
+ return out
+
+ types = [self._get_none_oid()] * nparams
+ pqformats = [TEXT] * nparams
+
+ for i in range(nparams):
+ param = params[i]
+ if param is None:
+ continue
+ dumper = self.get_dumper(param, formats[i])
+ out[i] = dumper.dump(param)
+ types[i] = dumper.oid
+ pqformats[i] = dumper.format
+
+ self.types = tuple(types)
+ self.formats = pqformats
+
+ return out
+
+ def as_literal(self, obj: Any) -> bytes:
+ dumper = self.get_dumper(obj, PY_TEXT)
+ rv = dumper.quote(obj)
+ # If the result is quoted, and the oid not unknown or text,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ oid = dumper.oid
+ if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
+ try:
+ type_sql = self._oid_types[oid]
+ except KeyError:
+ ti = self.adapters.types.get(oid)
+ if ti:
+ if oid < 8192:
+ # builtin: prefer "timestamptz" to "timestamp with time zone"
+ type_sql = ti.name.encode(self.encoding)
+ else:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+ else:
+ type_sql = b""
+ self._oid_types[oid] = type_sql
+
+ if type_sql:
+ rv = b"%s::%s" % (rv, type_sql)
+
+ if not isinstance(rv, bytes):
+ rv = bytes(rv)
+ return rv
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Return a Dumper instance to dump `!obj`.
+ """
+ # Normally, the type of the object dictates how to dump it
+ key = type(obj)
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._dumpers[format]
+ try:
+ dumper = cache[key]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper(key, format)
+ cache[key] = dumper = dcls(key, self)
+
+ # Check if the dumper requires an upgrade to handle this specific value
+ key1 = dumper.get_key(obj, format)
+ if key1 is key:
+ return dumper
+
+ # If it does, ask the dumper to create its own upgraded version
+ try:
+ return cache[key1]
+ except KeyError:
+ dumper = cache[key1] = dumper.upgrade(obj, format)
+ return dumper
+
+ def _get_none_oid(self) -> int:
+ try:
+ return self._none_oid
+ except AttributeError:
+ pass
+
+ try:
+ rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
+ except KeyError:
+ raise e.InterfaceError("None dumper not found")
+
+ return rv
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
+ """
+ Return a Dumper to dump an object to the type with given oid.
+ """
+ if not self._oid_dumpers:
+ self._oid_dumpers = ({}, {})
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._oid_dumpers[format]
+ try:
+ return cache[oid]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper_by_oid(oid, format)
+ cache[oid] = dumper = dcls(NoneType, self)
+
+ return dumper
+
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
+ res = self._pgresult
+ if not res:
+ raise e.InterfaceError("result not set")
+
+ if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
+ raise e.InterfaceError(
+ f"rows must be included between 0 and {self._ntuples}"
+ )
+
+ records = []
+ for row in range(row0, row1):
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+ records.append(make_row(record))
+
+ return records
+
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
+ res = self._pgresult
+ if not res:
+ return None
+
+ if not 0 <= row < self._ntuples:
+ return None
+
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+
+ return make_row(record)
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ if len(self._row_loaders) != len(record):
+ raise e.ProgrammingError(
+ f"cannot load sequence of {len(record)} items:"
+ f" {len(self._row_loaders)} loaders registered"
+ )
+
+ return tuple(
+ (self._row_loaders[i](val) if val is not None else None)
+ for i, val in enumerate(record)
+ )
+
+ def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ try:
+ return self._loaders[format][oid]
+ except KeyError:
+ pass
+
+ loader_cls = self._adapters.get_loader(oid, format)
+ if not loader_cls:
+ loader_cls = self._adapters.get_loader(INVALID_OID, format)
+ if not loader_cls:
+ raise e.InterfaceError("unknown oid loader not found")
+ loader = self._loaders[format][oid] = loader_cls(oid, self)
+ return loader
diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py
new file mode 100644
index 0000000..2f1a24d
--- /dev/null
+++ b/psycopg/psycopg/_typeinfo.py
@@ -0,0 +1,461 @@
+"""
+Information about PostgreSQL types
+
+These types allow to read information from the system catalog and provide
+information to the adapters if needed.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+from enum import Enum
+from typing import Any, Dict, Iterator, Optional, overload
+from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import errors as e
+from .abc import AdaptContext
+from .rows import dict_row
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+ from .sql import Identifier
+
+T = TypeVar("T", bound="TypeInfo")
+RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
+
+
+class TypeInfo:
+ """
+ Hold information about a PostgreSQL base type.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ delimiter: str = ",",
+ ):
+ self.name = name
+ self.oid = oid
+ self.array_oid = array_oid
+ self.regtype = regtype or name
+ self.delimiter = delimiter
+
+ def __repr__(self) -> str:
+ return (
+ f"<{self.__class__.__qualname__}:"
+ f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
+ )
+
+ @overload
+ @classmethod
+ def fetch(
+ cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
+ ) -> Optional[T]:
+ ...
+
+ @overload
+ @classmethod
+ async def fetch(
+ cls: Type[T],
+ conn: "AsyncConnection[Any]",
+ name: Union[str, "Identifier"],
+ ) -> Optional[T]:
+ ...
+
+ @classmethod
+ def fetch(
+ cls: Type[T],
+ conn: "Union[Connection[Any], AsyncConnection[Any]]",
+ name: Union[str, "Identifier"],
+ ) -> Any:
+ """Query a system catalog to read information about a type."""
+ from .sql import Composable
+ from .connection_async import AsyncConnection
+
+ if isinstance(name, Composable):
+ name = name.as_string(conn)
+
+ if isinstance(conn, AsyncConnection):
+ return cls._fetch_async(conn, name)
+
+ # This might result in a nested transaction. What we want is to leave
+ # the function with the connection in the state we found (either idle
+ # or intrans)
+ try:
+ with conn.transaction():
+ with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ async def _fetch_async(
+ cls: Type[T], conn: "AsyncConnection[Any]", name: str
+ ) -> Optional[T]:
+ """
+ Query a system catalog to read information about a type.
+
+ Similar to `fetch()` but can use an asynchronous connection.
+ """
+ try:
+ async with conn.transaction():
+ async with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ await cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = await cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ def _from_records(
+ cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
+ ) -> Optional[T]:
+ if len(recs) == 1:
+ return cls(**recs[0])
+ elif not recs:
+ return None
+ else:
+ raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
+
+ def register(self, context: Optional[AdaptContext] = None) -> None:
+ """
+ Register the type information, globally or in the specified `!context`.
+ """
+ if context:
+ types = context.adapters.types
+ else:
+ from . import postgres
+
+ types = postgres.types
+
+ types.add(self)
+
+ if self.array_oid:
+ from .types.array import register_array
+
+ register_array(self, context)
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ typname AS name, oid, typarray AS array_oid,
+ oid::regtype::text AS regtype, typdelim AS delimiter
+FROM pg_type t
+WHERE t.oid = %(name)s::regtype
+ORDER BY t.oid
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ """Method called by the `!registry` when the object is added there."""
+ pass
+
+
+class RangeInfo(TypeInfo):
+ """Manage information about a range type."""
+
+ __module__ = "psycopg.types.range"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngtypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map ranges subtypes to info
+ registry._registry[RangeInfo, self.subtype_oid] = self
+
+
+class MultirangeInfo(TypeInfo):
+ """Manage information about a multirange type."""
+
+ # TODO: expose to multirange module once added
+ # __module__ = "psycopg.types.multirange"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ range_oid: int,
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.range_oid = range_oid
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ if conn.info.server_version < 140000:
+ raise e.NotSupportedError(
+ "multirange types are only available from PostgreSQL 14"
+ )
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngmultitypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map multiranges ranges and subtypes to info
+ registry._registry[MultirangeInfo, self.range_oid] = self
+ registry._registry[MultirangeInfo, self.subtype_oid] = self
+
+
+class CompositeInfo(TypeInfo):
+ """Manage information about a composite type."""
+
+ __module__ = "psycopg.types.composite"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ field_names: Sequence[str],
+ field_types: Sequence[int],
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.field_names = field_names
+ self.field_types = field_types
+ # Will be set by register() if the `factory` is a type
+ self.python_type: Optional[type] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ coalesce(a.fnames, '{}') AS field_names,
+ coalesce(a.ftypes, '{}') AS field_types
+FROM pg_type t
+LEFT JOIN (
+ SELECT
+ attrelid,
+ array_agg(attname) AS fnames,
+ array_agg(atttypid) AS ftypes
+ FROM (
+ SELECT a.attrelid, a.attname, a.atttypid
+ FROM pg_attribute a
+ JOIN pg_type t ON t.typrelid = a.attrelid
+ WHERE t.oid = %(name)s::regtype
+ AND a.attnum > 0
+ AND NOT a.attisdropped
+ ORDER BY a.attnum
+ ) x
+ GROUP BY attrelid
+) a ON a.attrelid = t.typrelid
+WHERE t.oid = %(name)s::regtype
+"""
+
+
+class EnumInfo(TypeInfo):
+ """Manage information about an enum type."""
+
+ __module__ = "psycopg.types.enum"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ labels: Sequence[str],
+ ):
+ super().__init__(name, oid, array_oid)
+ self.labels = labels
+ # Will be set by register_enum()
+ self.enum: Optional[Type[Enum]] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT name, oid, array_oid, array_agg(label) AS labels
+FROM (
+ SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ e.enumlabel AS label
+ FROM pg_type t
+ LEFT JOIN pg_enum e
+ ON e.enumtypid = t.oid
+ WHERE t.oid = %(name)s::regtype
+ ORDER BY e.enumsortorder
+) x
+GROUP BY name, oid, array_oid
+"""
+
+
+class TypesRegistry:
+ """
+ Container for the information about types in a database.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(self, template: Optional["TypesRegistry"] = None):
+ self._registry: Dict[RegistryKey, TypeInfo]
+
+ # Make a shallow copy: it will become a proper copy if the registry
+ # is edited.
+ if template:
+ self._registry = template._registry
+ self._own_state = False
+ template._own_state = False
+ else:
+ self.clear()
+
+ def clear(self) -> None:
+ self._registry = {}
+ self._own_state = True
+
+ def add(self, info: TypeInfo) -> None:
+ self._ensure_own_state()
+ if info.oid:
+ self._registry[info.oid] = info
+ if info.array_oid:
+ self._registry[info.array_oid] = info
+ self._registry[info.name] = info
+
+ if info.regtype and info.regtype not in self._registry:
+ self._registry[info.regtype] = info
+
+ # Allow info to customise further their relation with the registry
+ info._added(self)
+
+ def __iter__(self) -> Iterator[TypeInfo]:
+ seen = set()
+ for t in self._registry.values():
+ if id(t) not in seen:
+ seen.add(id(t))
+ yield t
+
+ @overload
+ def __getitem__(self, key: Union[str, int]) -> TypeInfo:
+ ...
+
+ @overload
+ def __getitem__(self, key: Tuple[Type[T], int]) -> T:
+ ...
+
+ def __getitem__(self, key: RegistryKey) -> TypeInfo:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Raise KeyError if not found.
+ """
+ if isinstance(key, str):
+ if key.endswith("[]"):
+ key = key[:-2]
+ elif not isinstance(key, (int, tuple)):
+ raise TypeError(f"the key must be an oid or a name, got {type(key)}")
+ try:
+ return self._registry[key]
+ except KeyError:
+ raise KeyError(f"couldn't find the type {key!r} in the types registry")
+
+ @overload
+ def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
+ ...
+
+ def get(self, key: RegistryKey) -> Optional[TypeInfo]:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Unlike `__getitem__`, return None if not found.
+ """
+ try:
+ return self[key]
+ except KeyError:
+ return None
+
+ def get_oid(self, name: str) -> int:
+ """
+ Return the oid of a PostgreSQL type by name.
+
+ :param key: the name of the type to look for.
+
+ Return the array oid if the type ends with "``[]``"
+
+ Raise KeyError if the name is unknown.
+ """
+ t = self[name]
+ if name.endswith("[]"):
+ return t.array_oid
+ else:
+ return t.oid
+
+ def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
+ """
+ Return info about a `TypeInfo` subclass by its element name or oid.
+
+ :param cls: the subtype of `!TypeInfo` to look for. Currently
+ supported are `~psycopg.types.range.RangeInfo` and
+ `~psycopg.types.multirange.MultirangeInfo`.
+ :param subtype: The name or OID of the subtype of the element to look for.
+ :return: The `!TypeInfo` object of class `!cls` whose subtype is
+ `!subtype`. `!None` if the element or its range are not found.
+ """
+ try:
+ info = self[subtype]
+ except KeyError:
+ return None
+ return self.get((cls, info.oid))
+
+ def _ensure_own_state(self) -> None:
+ # Time to write! so, copy.
+ if not self._own_state:
+ self._registry = self._registry.copy()
+ self._own_state = True
diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py
new file mode 100644
index 0000000..813ed62
--- /dev/null
+++ b/psycopg/psycopg/_tz.py
@@ -0,0 +1,44 @@
+"""
+Timezone utility functions.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import Dict, Optional, Union
+from datetime import timezone, tzinfo
+
+from .pq.abc import PGconn
+from ._compat import ZoneInfo
+
+logger = logging.getLogger("psycopg")
+
+_timezones: Dict[Union[None, bytes], tzinfo] = {
+ None: timezone.utc,
+ b"UTC": timezone.utc,
+}
+
+
+def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
+ """Return the Python timezone info of the connection's timezone."""
+ tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
+ try:
+ return _timezones[tzname]
+ except KeyError:
+ sname = tzname.decode() if tzname else "UTC"
+ try:
+ zi: tzinfo = ZoneInfo(sname)
+ except (KeyError, OSError):
+ logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
+ zi = timezone.utc
+ except Exception as ex:
+ logger.warning(
+ "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
+ sname,
+ type(ex).__name__,
+ ex,
+ )
+ zi = timezone.utc
+
+ _timezones[tzname] = zi
+ return zi
diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py
new file mode 100644
index 0000000..f861741
--- /dev/null
+++ b/psycopg/psycopg/_wrappers.py
@@ -0,0 +1,137 @@
+"""
+Wrappers for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+# These types are implemented here but exposed by `psycopg.types.numeric`.
+# They are defined here to avoid a circular import.
+_MODULE = "psycopg.types.numeric"
+
+
+class Int2(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int2":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int4(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int8(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class IntNumeric(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "IntNumeric":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float4(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float8(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Oid(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Oid":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py
new file mode 100644
index 0000000..80c8fbf
--- /dev/null
+++ b/psycopg/psycopg/abc.py
@@ -0,0 +1,266 @@
+"""
+Protocol objects representing different implementations of the same classes.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Generator, Mapping
+from typing import List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._enums import PyFormat as PyFormat
+from ._compat import Protocol, LiteralString
+
+if TYPE_CHECKING:
+ from . import sql
+ from .rows import Row, RowMaker
+ from .pq.abc import PGresult
+ from .waiting import Wait, Ready
+ from .connection import BaseConnection
+ from ._adapters_map import AdaptersMap
+
+NoneType: type = type(None)
+
+# An object implementing the buffer protocol
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
+Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
+PipelineCommand: TypeAlias = Callable[[], None]
+DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+
+# Waiting protocol types
+
+RV = TypeVar("RV")
+
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+"""Generator for processes where the connection file number can change.
+
+This can happen in connection and reset, but not in normal querying.
+"""
+
+PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+"""Generator for processes where the connection file number won't change.
+"""
+
+
+class WaitFunc(Protocol):
+ """
+ Wait on the connection which generated `PQgen` and return its final result.
+ """
+
+ def __call__(
+ self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
+ ) -> RV:
+ ...
+
+
+# Adaptation types
+
+DumpFunc: TypeAlias = Callable[[Any], Buffer]
+LoadFunc: TypeAlias = Callable[[Buffer], Any]
+
+
+class AdaptContext(Protocol):
+ """
+ A context describing how types are adapted.
+
+ Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
+ `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
+
+ Note that this is a `~typing.Protocol`, so objects implementing
+ `!AdaptContext` don't need to explicitly inherit from this class.
+
+ """
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ """The adapters configuration that this object uses."""
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ """The connection used by this object, if available.
+
+ :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
+ """
+ ...
+
+
+class Dumper(Protocol):
+ """
+ Convert Python objects of type `!cls` to PostgreSQL representation.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `dump()` method produces,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ oid: int
+ """The oid to pass to the server, if known; 0 otherwise (class attribute)."""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ ...
+
+ def dump(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to PostgreSQL representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to escaped representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
+ """Return an alternative key to upgrade the dumper to represent `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Normally the type of the object is all it takes to define how to dump
+ the object to the database. For instance, a Python `~datetime.date` can
+ be simply converted into a PostgreSQL :sql:`date`.
+
+ In a few cases, just the type is not enough. For example:
+
+ - A Python `~datetime.datetime` could be represented as a
+ :sql:`timestamptz` or a :sql:`timestamp`, according to whether it
+ specifies a `!tzinfo` or not.
+
+ - A Python int could be stored as several Postgres types: int2, int4,
+ int8, numeric. If a type too small is used, it may result in an
+ overflow. If a type too large is used, PostgreSQL may not want to
+ cast it to a smaller type.
+
+ - Python lists should be dumped according to the type they contain to
+ convert them to e.g. array of strings, array of ints (and which
+ size of int?...)
+
+ In these cases, a dumper can implement `!get_key()` and return a new
+ class, or sequence of classes, that can be used to identify the same
+ dumper again. If the mechanism is not needed, the method should return
+ the same `!cls` object passed in the constructor.
+
+ If a dumper implements `get_key()` it should also implement
+ `upgrade()`.
+
+ """
+ ...
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """Return a new dumper to manage `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Once `Transformer.get_dumper()` has been notified by `get_key()` that
+ this Dumper class cannot handle `!obj` itself, it will invoke
+ `!upgrade()`, which should return a new `Dumper` instance, which will
+ be reused for every objects for which `!get_key()` returns the same
+ result.
+ """
+ ...
+
+
+class Loader(Protocol):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `load()` method can convert,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ ...
+
+ def load(self, data: Buffer) -> Any:
+ """
+ Convert the data returned by the database into a Python object.
+
+ :param data: the data to convert.
+ """
+ ...
+
+
+class Transformer(Protocol):
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ ...
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ ...
+
+ @property
+ def encoding(self) -> str:
+ ...
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ ...
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ ...
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None
+ ) -> None:
+ ...
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ ...
+
+ def as_literal(self, obj: Any) -> bytes:
+ ...
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
+ ...
+
+ def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
+ ...
+
+ def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
+ ...
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ ...
+
+ def get_loader(self, oid: int, format: pq.Format) -> Loader:
+ ...
diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py
new file mode 100644
index 0000000..7ec4a55
--- /dev/null
+++ b/psycopg/psycopg/adapt.py
@@ -0,0 +1,162 @@
+"""
+Entry point into the adaptation system.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type, TYPE_CHECKING
+
+from . import pq, abc
+from . import _adapters_map
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+AdaptersMap = _adapters_map.AdaptersMap
+Buffer = abc.Buffer
+
+ORD_BS = ord("\\")
+
+
+class Dumper(abc.Dumper, ABC):
+ """
+ Convert Python object of the type `!cls` to PostgreSQL representation.
+ """
+
+ oid: int = 0
+ """The oid to pass to the server, if known."""
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data dumped."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ self.cls = cls
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<{type(self).__module__}.{type(self).__qualname__}"
+ f" (oid={self.oid}) at 0x{id(self):x}>"
+ )
+
+ @abstractmethod
+ def dump(self, obj: Any) -> Buffer:
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """
+ By default return the `dump()` value quoted and sanitised, so
+ that the result can be used to build a SQL string. This works well
+ for most types and you won't likely have to implement this method in a
+ subclass.
+ """
+ value = self.dump(obj)
+
+ if self.connection:
+ esc = pq.Escaping(self.connection.pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+
+ # This path is taken when quote is asked without a connection,
+ # usually it means by psycopg.sql.quote() or by
+ # 'Composible.as_string(None)'. Most often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ # No quoting, only quote escaping, random bs escaping. See further.
+ esc = pq.Escaping()
+ out = esc.escape_string(value)
+
+ # b"\\" in memoryview doesn't work so search for the ascii value
+ if ORD_BS not in out:
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ return b"'" + out + b"'"
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ rv: bytes = b" E'" + out + b"'"
+ if esc.escape_string(b"\\") == b"\\":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+ def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
+ """
+ Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation returns the `!cls` passed in the constructor.
+ Subclasses needing to specialise the PostgreSQL type according to the
+ *value* of the object dumped (not only according to to its type)
+ should override this class.
+
+ """
+ return self.cls
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation just returns `!self`. If a subclass implements
+ `get_key()` it should probably override `!upgrade()` too.
+ """
+ return self
+
+
+class Loader(abc.Loader, ABC):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data loaded."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ self.oid = oid
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ @abstractmethod
+ def load(self, data: Buffer) -> Any:
+ """Convert a PostgreSQL value to a Python object."""
+ ...
+
+
+Transformer: Type["abc.Transformer"]
+
+# Override it with fast object if available
+if _psycopg:
+ Transformer = _psycopg.Transformer
+else:
+ from . import _transform
+
+ Transformer = _transform.Transformer
+
+
+class RecursiveDumper(Dumper):
+ """Dumper with a transformer to help dumping recursive types."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self._tx = Transformer.from_context(context)
+
+
+class RecursiveLoader(Loader):
+ """Loader with a transformer to help loading recursive types."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer.from_context(context)
diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py
new file mode 100644
index 0000000..6271ec5
--- /dev/null
+++ b/psycopg/psycopg/client_cursor.py
@@ -0,0 +1,95 @@
+"""
+psycopg client-side binding cursors
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from typing import Optional, Tuple, TYPE_CHECKING
+from functools import partial
+
+from ._queries import PostgresQuery, PostgresClientQuery
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params
+from .rows import Row
+from .cursor import BaseCursor, Cursor
+from ._preparing import Prepare
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from typing import Any # noqa: F401
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+
+class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
+ def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
+ """
+ Return the query and parameters merged.
+
+ Parameters are adapted and merged to the query the same way that
+ `!execute()` would do.
+
+ """
+ self._tx = adapt.Transformer(self)
+ pgq = self._convert_query(query, params)
+ return pgq.query.decode(self._tx.encoding)
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if fmt == BINARY:
+ raise e.NotSupportedError(
+ "client-side cursors don't support binary results"
+ )
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(self._pgconn.send_query_params, query.query, None)
+ )
+ elif force_extended:
+ self._pgconn.send_query_params(query.query, None)
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return (Prepare.NO, b"")
+
+
+class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+
+
+class AsyncClientCursor(
+ ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
new file mode 100644
index 0000000..78ad577
--- /dev/null
+++ b/psycopg/psycopg/connection.py
@@ -0,0 +1,1031 @@
+"""
+psycopg connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+import threading
+from types import TracebackType
+from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
+from typing import overload, TYPE_CHECKING
+from weakref import ref, ReferenceType
+from warnings import warn
+from functools import partial
+from contextlib import contextmanager
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from . import waiting
+from . import postgres
+from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import PQGen, PQGenConn
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from ._compat import LiteralString
+from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from ._pipeline import BasePipeline, Pipeline
+from .generators import notifies, connect, execute
+from ._encodings import pgconn_encoding
+from ._preparing import PrepareManager
+from .transaction import Transaction
+from .server_cursor import ServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn, PGresult
+ from psycopg_pool.base import BasePool
+
+
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class Notify(NamedTuple):
+ """An asynchronous notification received from the database."""
+
+ channel: str
+ """The name of the channel on which the notification was received."""
+
+ payload: str
+ """The message attached to the notification."""
+
+ pid: int
+ """The PID of the backend process which sent the notification."""
+
+
+Notify.__module__ = "psycopg"
+
+NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
+NotifyHandler: TypeAlias = Callable[[Notify], None]
+
+
+class BaseConnection(Generic[Row]):
+ """
+ Base class for different types of connections.
+
+ Share common functionalities such as access to the wrapped PGconn, but
+ allow different interfaces (sync/async).
+ """
+
+ # DBAPI2 exposed exceptions
+ Warning = e.Warning
+ Error = e.Error
+ InterfaceError = e.InterfaceError
+ DatabaseError = e.DatabaseError
+ DataError = e.DataError
+ OperationalError = e.OperationalError
+ IntegrityError = e.IntegrityError
+ InternalError = e.InternalError
+ ProgrammingError = e.ProgrammingError
+ NotSupportedError = e.NotSupportedError
+
+ # Enums useful for the connection
+ ConnStatus = pq.ConnStatus
+ TransactionStatus = pq.TransactionStatus
+
+ def __init__(self, pgconn: "PGconn"):
+ self.pgconn = pgconn
+ self._autocommit = False
+
+ # None, but set to a copy of the global adapters map as soon as requested.
+ self._adapters: Optional[AdaptersMap] = None
+
+ self._notice_handlers: List[NoticeHandler] = []
+ self._notify_handlers: List[NotifyHandler] = []
+
+ # Number of transaction blocks currently entered
+ self._num_transactions = 0
+
+ self._closed = False # closed by an explicit close()
+ self._prepared: PrepareManager = PrepareManager()
+ self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared
+
+ wself = ref(self)
+ pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+ pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
+
+ # Attribute is only set if the connection is from a pool so we can tell
+ # apart a connection in the pool too (when _pool = None)
+ self._pool: Optional["BasePool[Any]"]
+
+ self._pipeline: Optional[BasePipeline] = None
+
+ # Time after which the connection should be closed
+ self._expire_at: float
+
+ self._isolation_level: Optional[IsolationLevel] = None
+ self._read_only: Optional[bool] = None
+ self._deferrable: Optional[bool] = None
+ self._begin_statement = b""
+
+ def __del__(self) -> None:
+ # If fails on connection we might not have this attribute yet
+ if not hasattr(self, "pgconn"):
+ return
+
+ # Connection correctly closed
+ if self.closed:
+ return
+
+ # Connection in a pool so terminating with the program is normal
+ if hasattr(self, "_pool"):
+ return
+
+ warn(
+ f"connection {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the connection",
+ ResourceWarning,
+ )
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def closed(self) -> bool:
+ """`!True` if the connection is closed."""
+ return self.pgconn.status == BAD
+
+ @property
+ def broken(self) -> bool:
+ """
+ `!True` if the connection was interrupted.
+
+ A broken connection is always `closed`, but wasn't closed in a clean
+ way, such as using `close()` or a `!with` block.
+ """
+ return self.pgconn.status == BAD and not self._closed
+
+ @property
+ def autocommit(self) -> bool:
+ """The autocommit state of the connection."""
+ return self._autocommit
+
+ @autocommit.setter
+ def autocommit(self, value: bool) -> None:
+ self._set_autocommit(value)
+
+ def _set_autocommit(self, value: bool) -> None:
+ raise NotImplementedError
+
+ def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
+ yield from self._check_intrans_gen("autocommit")
+ self._autocommit = bool(value)
+
+ @property
+ def isolation_level(self) -> Optional[IsolationLevel]:
+ """
+ The isolation level of the new transactions started on the connection.
+ """
+ return self._isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._set_isolation_level(value)
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ raise NotImplementedError
+
+ def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
+ yield from self._check_intrans_gen("isolation_level")
+ self._isolation_level = IsolationLevel(value) if value is not None else None
+ self._begin_statement = b""
+
+ @property
+ def read_only(self) -> Optional[bool]:
+ """
+ The read-only state of the new transactions started on the connection.
+ """
+ return self._read_only
+
+ @read_only.setter
+ def read_only(self, value: Optional[bool]) -> None:
+ self._set_read_only(value)
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("read_only")
+ self._read_only = bool(value)
+ self._begin_statement = b""
+
+ @property
+ def deferrable(self) -> Optional[bool]:
+ """
+ The deferrable state of the new transactions started on the connection.
+ """
+ return self._deferrable
+
+ @deferrable.setter
+ def deferrable(self, value: Optional[bool]) -> None:
+ self._set_deferrable(value)
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("deferrable")
+ self._deferrable = bool(value)
+ self._begin_statement = b""
+
+ def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
+ # Raise an exception if we are in a transaction
+ status = self.pgconn.transaction_status
+ if status == IDLE and self._pipeline:
+ yield from self._pipeline._sync_gen()
+ status = self.pgconn.transaction_status
+ if status != IDLE:
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection.transaction() context in progress"
+ )
+ else:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection in transaction status "
+ f"{pq.TransactionStatus(status).name}"
+ )
+
+ @property
+ def info(self) -> ConnectionInfo:
+ """A `ConnectionInfo` attribute to inspect connection properties."""
+ return ConnectionInfo(self.pgconn)
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ self._adapters = AdaptersMap(postgres.adapters)
+
+ return self._adapters
+
+ @property
+ def connection(self) -> "BaseConnection[Row]":
+ # implement the AdaptContext protocol
+ return self
+
+ def fileno(self) -> int:
+ """Return the file descriptor of the connection.
+
+ This function allows to use the connection as file-like object in
+ functions waiting for readiness, such as the ones defined in the
+ `selectors` module.
+ """
+ return self.pgconn.socket
+
+ def cancel(self) -> None:
+ """Cancel the current operation on the connection."""
+ # No-op if the connection is closed
+ # this allows to use the method as callback handler without caring
+ # about its life.
+ if self.closed:
+ return
+
+ if self._tpc and self._tpc[1]:
+ raise e.ProgrammingError(
+ "cancel() cannot be used with a prepared two-phase transaction"
+ )
+
+ c = self.pgconn.get_cancel()
+ c.cancel()
+
+ def add_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Register a callable to be invoked when a notice message is received.
+
+ :param callback: the callback to call upon message received.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.append(callback)
+
+ def remove_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Unregister a notice message callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.remove(callback)
+
+ @staticmethod
+ def _notice_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
+ ) -> None:
+ self = wself()
+ if not (self and self._notice_handlers):
+ return
+
+ diag = e.Diagnostic(res, pgconn_encoding(self.pgconn))
+ for cb in self._notice_handlers:
+ try:
+ cb(diag)
+ except Exception as ex:
+ logger.exception("error processing notice callback '%s': %s", cb, ex)
+
+ def add_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Register a callable to be invoked whenever a notification is received.
+
+ :param callback: the callback to call upon notification received.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.append(callback)
+
+ def remove_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Unregister a notification callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.remove(callback)
+
+ @staticmethod
+ def _notify_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+ ) -> None:
+ self = wself()
+ if not (self and self._notify_handlers):
+ return
+
+ enc = pgconn_encoding(self.pgconn)
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ for cb in self._notify_handlers:
+ cb(n)
+
+ @property
+ def prepare_threshold(self) -> Optional[int]:
+ """
+ Number of times a query is executed before it is prepared.
+
+ - If it is set to 0, every query is prepared the first time it is
+ executed.
+ - If it is set to `!None`, prepared statements are disabled on the
+ connection.
+
+ Default value: 5
+ """
+ return self._prepared.prepare_threshold
+
+ @prepare_threshold.setter
+ def prepare_threshold(self, value: Optional[int]) -> None:
+ self._prepared.prepare_threshold = value
+
+ @property
+ def prepared_max(self) -> int:
+ """
+ Maximum number of prepared statements on the connection.
+
+ Default value: 100
+ """
+ return self._prepared.prepared_max
+
+ @prepared_max.setter
+ def prepared_max(self, value: int) -> None:
+ self._prepared.prepared_max = value
+
+ # Generators to perform high-level operations on the connection
+ #
+ # These operations are expressed in terms of non-blocking generators
+ # and the task of waiting when needed (when the generators yield) is left
+ # to the connections subclass, which might wait either in blocking mode
+ # or through asyncio.
+ #
+ # All these generators assume exclusive access to the connection: subclasses
+ # should have a lock and hold it before calling and consuming them.
+
+ @classmethod
+ def _connect_gen(
+ cls: Type[ConnectionType],
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ ) -> PQGenConn[ConnectionType]:
+ """Generator to connect to the database and create a new instance."""
+ pgconn = yield from connect(conninfo)
+ conn = cls(pgconn)
+ conn._autocommit = bool(autocommit)
+ return conn
+
+ def _exec_command(
+ self, command: Query, result_format: pq.Format = TEXT
+ ) -> PQGen[Optional["PGresult"]]:
+ """
+ Generator to send a command and receive the result to the backend.
+
+ Only used to implement internal commands such as "commit", with eventual
+ arguments bound client-side. The cursor can do more complex stuff.
+ """
+ self._check_connection_ok()
+
+ if isinstance(command, str):
+ command = command.encode(pgconn_encoding(self.pgconn))
+ elif isinstance(command, Composable):
+ command = command.as_bytes(self)
+
+ if self._pipeline:
+ cmd = partial(
+ self.pgconn.send_query_params,
+ command,
+ None,
+ result_format=result_format,
+ )
+ self._pipeline.command_queue.append(cmd)
+ self._pipeline.result_queue.append(None)
+ return None
+
+ self.pgconn.send_query_params(command, None, result_format=result_format)
+
+ result = (yield from execute(self.pgconn))[-1]
+ if result.status != COMMAND_OK and result.status != TUPLES_OK:
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ else:
+ raise e.InterfaceError(
+ f"unexpected result {pq.ExecStatus(result.status).name}"
+ f" from command {command.decode()!r}"
+ )
+ return result
+
+ def _check_connection_ok(self) -> None:
+ if self.pgconn.status == OK:
+ return
+
+ if self.pgconn.status == BAD:
+ raise e.OperationalError("the connection is closed")
+ raise e.InterfaceError(
+ "cannot execute operations: the connection is"
+ f" in status {self.pgconn.status}"
+ )
+
+ def _start_query(self) -> PQGen[None]:
+ """Generator to start a transaction if necessary."""
+ if self._autocommit:
+ return
+
+ if self.pgconn.transaction_status != IDLE:
+ return
+
+ yield from self._exec_command(self._get_tx_start_command())
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _get_tx_start_command(self) -> bytes:
+ if self._begin_statement:
+ return self._begin_statement
+
+ parts = [b"BEGIN"]
+
+ if self.isolation_level is not None:
+ val = IsolationLevel(self.isolation_level)
+ parts.append(b"ISOLATION LEVEL")
+ parts.append(val.name.replace("_", " ").encode())
+
+ if self.read_only is not None:
+ parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
+
+ if self.deferrable is not None:
+ parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
+
+ self._begin_statement = b" ".join(parts)
+ return self._begin_statement
+
+ def _commit_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.commit()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit commit() forbidden within a Transaction "
+ "context. (Transaction will be automatically committed "
+ "on successful exit from context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "commit() cannot be used during a two-phase transaction"
+ )
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"COMMIT")
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _rollback_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.rollback()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit rollback() forbidden within a Transaction "
+ "context. (Either raise Rollback() or allow "
+ "an exception to propagate out of the context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "rollback() cannot be used during a two-phase transaction"
+ )
+
+ # Get out of a "pipeline aborted" state
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"ROLLBACK")
+ self._prepared.clear()
+ for cmd in self._prepared.get_maintenance_commands():
+ yield from self._exec_command(cmd)
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+ """
+ Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
+
+ The argument types and constraints are explained in
+ :ref:`two-phase-commit`.
+
+ The values passed to the method will be available on the returned
+ object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
+ """
+ self._check_tpc()
+ return Xid.from_parts(format_id, gtrid, bqual)
+
+ def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+ self._check_tpc()
+
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self.pgconn.transaction_status != IDLE:
+ raise e.ProgrammingError(
+ "can't start two-phase transaction: connection in status"
+ f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
+ )
+
+ if self._autocommit:
+ raise e.ProgrammingError(
+ "can't use two-phase transactions in autocommit mode"
+ )
+
+ self._tpc = (xid, False)
+ yield from self._exec_command(self._get_tx_start_command())
+
+ def _tpc_prepare_gen(self) -> PQGen[None]:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' must be called inside a two-phase transaction"
+ )
+ if self._tpc[1]:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
+ )
+ xid = self._tpc[0]
+ self._tpc = (xid, True)
+ yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _tpc_finish_gen(
+ self, action: LiteralString, xid: Union[Xid, str, None]
+ ) -> PQGen[None]:
+ fname = f"tpc_{action.lower()}()"
+ if xid is None:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} without xid must must be"
+ " called inside a two-phase transaction"
+ )
+ xid = self._tpc[0]
+ else:
+ if self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} with xid must must be called"
+ " outside a two-phase transaction"
+ )
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self._tpc and not self._tpc[1]:
+ meth: Callable[[], PQGen[None]]
+ meth = getattr(self, f"_{action.lower()}_gen")
+ self._tpc = None
+ yield from meth()
+ else:
+ yield from self._exec_command(
+ SQL("{} PREPARED {}").format(SQL(action), str(xid))
+ )
+ self._tpc = None
+
+ def _check_tpc(self) -> None:
+ """Raise NotSupportedError if TPC is not supported."""
+ # TPC supported on every supported PostgreSQL version.
+ pass
+
+
+class Connection(BaseConnection[Row]):
+ """
+ Wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[Cursor[Row]]
+ server_cursor_factory: Type[ServerCursor[Row]]
+ row_factory: RowFactory[Row]
+ _pipeline: Optional[Pipeline]
+ _Self = TypeVar("_Self", bound="Connection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = threading.Lock()
+ self.cursor_factory = Cursor
+ self.server_cursor_factory = ServerCursor
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: Optional[RowFactory[Row]] = None,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Any,
+ ) -> "Connection[Any]":
+ """
+ Connect to a database server and return a new `Connection` instance.
+ """
+ params = cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ self.close()
+
+ @classmethod
+ def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ :param conninfo: Connection string as received by `~Connection.connect()`.
+ :param kwargs: Overriding connection arguments as received by `!connect()`.
+ :return: Connection arguments merged and eventually modified, in a
+ format similar to `~conninfo.conninfo_to_dict()`.
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ return params
+
+ def close(self) -> None:
+ """Close the database connection."""
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> Cursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+ ) -> Cursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: RowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[RowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[Cursor[Any], ServerCursor[Any]]:
+ """
+ Return a new cursor to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[Cursor[Any], ServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> Cursor[Row]:
+ """Execute a query and return a cursor to read its results."""
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def commit(self) -> None:
+ """Commit any pending transaction to the database."""
+ with self.lock:
+ self.wait(self._commit_gen())
+
+ def rollback(self) -> None:
+ """Roll back to the start of any pending transaction."""
+ with self.lock:
+ self.wait(self._rollback_gen())
+
+ @contextmanager
+ def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> Iterator[Transaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :param savepoint_name: Name of the savepoint used to manage a nested
+ transaction. If `!None`, one will be chosen automatically.
+ :param force_rollback: Roll back the transaction at the end of the
+ block even if there were no error (e.g. to try a no-op process).
+ :rtype: Transaction
+ """
+ tx = Transaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ with tx:
+ yield tx
+
+ def notifies(self) -> Generator[Notify, None, None]:
+ """
+ Yield `Notify` objects as soon as they are received from the database.
+ """
+ while True:
+ with self.lock:
+ try:
+ ns = self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
+
+ @contextmanager
+ def pipeline(self) -> Iterator[Pipeline]:
+ """Switch the connection into pipeline mode."""
+ with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = Pipeline(self)
+
+ try:
+ with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
+
+ def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ """
+ Consume a generator operating on the connection.
+
+ The function must be used on generators that don't change connection
+ fd (i.e. not on connect and reset).
+ """
+ try:
+ return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except KeyboardInterrupt:
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will remain stuck in ACTIVE state.
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ """Consume a connection generator."""
+ return waiting.wait_conn(gen, timeout=timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ with self.lock:
+ self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ with self.lock:
+ self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_deferrable_gen(value))
+
+ def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ """
+ Begin a TPC transaction with the given transaction ID `!xid`.
+ """
+ with self.lock:
+ self.wait(self._tpc_begin_gen(xid))
+
+ def tpc_prepare(self) -> None:
+ """
+ Perform the first phase of a transaction started with `tpc_begin()`.
+ """
+ try:
+ with self.lock:
+ self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Commit a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("COMMIT", xid))
+
+ def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Roll back a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("ROLLBACK", xid))
+
+ def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ cur.execute(Xid._get_recover_query())
+ res = cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
new file mode 100644
index 0000000..aa02dc0
--- /dev/null
+++ b/psycopg/psycopg/connection_async.py
@@ -0,0 +1,436 @@
+"""
+psycopg async connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import asyncio
+import logging
+from types import TracebackType
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from . import waiting
+from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from ._tpc import Xid
+from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from ._pipeline import AsyncPipeline
+from ._encodings import pgconn_encoding
+from .connection import BaseConnection, CursorRow, Notify
+from .generators import notifies
+from .transaction import AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class AsyncConnection(BaseConnection[Row]):
+ """
+ Asynchronous wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[AsyncCursor[Row]]
+ server_cursor_factory: Type[AsyncServerCursor[Row]]
+ row_factory: AsyncRowFactory[Row]
+ _pipeline: Optional[AsyncPipeline]
+ _Self = TypeVar("_Self", bound="AsyncConnection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = asyncio.Lock()
+ self.cursor_factory = AsyncCursor
+ self.server_cursor_factory = AsyncServerCursor
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ context: Optional[AdaptContext] = None,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ **kwargs: Any,
+ ) -> "AsyncConnection[Any]":
+
+ if sys.platform == "win32":
+ loop = asyncio.get_running_loop()
+ if isinstance(loop, asyncio.ProactorEventLoop):
+ raise e.InterfaceError(
+ "Psycopg cannot use the 'ProactorEventLoop' to run in async"
+ " mode. Please use a compatible event loop, for instance by"
+ " setting 'asyncio.set_event_loop_policy"
+ "(WindowsSelectorEventLoopPolicy())'"
+ )
+
+ params = await cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = await cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ await self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ await self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ await self.close()
+
+ @classmethod
+ async def _get_connection_params(
+ cls, conninfo: str, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ .. versionchanged:: 3.1
+ Unlike the sync counterpart, perform non-blocking address
+ resolution and populate the ``hostaddr`` connection parameter,
+ unless the user has provided one themselves. See
+ `~psycopg._dns.resolve_hostaddr_async()` for details.
+
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ # Resolve host addresses in non-blocking way
+ params = await resolve_hostaddr_async(params)
+
+ return params
+
+ async def close(self) -> None:
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
+ ) -> AsyncCursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: AsyncRowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[AsyncRowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
+ """
+ Return a new `AsyncCursor` to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ async def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> AsyncCursor[Row]:
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return await cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def commit(self) -> None:
+ async with self.lock:
+ await self.wait(self._commit_gen())
+
+ async def rollback(self) -> None:
+ async with self.lock:
+ await self.wait(self._rollback_gen())
+
+ @asynccontextmanager
+ async def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> AsyncIterator[AsyncTransaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :rtype: AsyncTransaction
+ """
+ tx = AsyncTransaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ async with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ async with tx:
+ yield tx
+
+ async def notifies(self) -> AsyncGenerator[Notify, None]:
+ while True:
+ async with self.lock:
+ try:
+ ns = await self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
+
+ @asynccontextmanager
+ async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
+ """Context manager to switch the connection into pipeline mode."""
+ async with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = AsyncPipeline(self)
+
+ try:
+ async with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ async with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
+
+ async def wait(self, gen: PQGen[RV]) -> RV:
+ try:
+ return await waiting.wait_async(gen, self.pgconn.socket)
+ except KeyboardInterrupt:
+ # TODO: this doesn't seem to work as it does for sync connections
+ # see tests/test_concurrency_async.py::test_ctrl_c
+ # In the test, the code doesn't reach this branch.
+
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # otherwise the connection will be stuck in ACTIVE state
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ await waiting.wait_async(gen, self.pgconn.socket)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ return await waiting.wait_conn_async(gen, timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ self._no_set_async("autocommit")
+
+ async def set_autocommit(self, value: bool) -> None:
+ """Async version of the `~Connection.autocommit` setter."""
+ async with self.lock:
+ await self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._no_set_async("isolation_level")
+
+ async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ """Async version of the `~Connection.isolation_level` setter."""
+ async with self.lock:
+ await self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ self._no_set_async("read_only")
+
+ async def set_read_only(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.read_only` setter."""
+ async with self.lock:
+ await self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ self._no_set_async("deferrable")
+
+ async def set_deferrable(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.deferrable` setter."""
+ async with self.lock:
+ await self.wait(self._set_deferrable_gen(value))
+
+ def _no_set_async(self, attribute: str) -> None:
+ raise AttributeError(
+ f"'the {attribute!r} property is read-only on async connections:"
+ f" please use 'await .set_{attribute}()' instead."
+ )
+
+ async def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_begin_gen(xid))
+
+ async def tpc_prepare(self) -> None:
+ try:
+ async with self.lock:
+ await self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("commit", xid))
+
+ async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("rollback", xid))
+
+ async def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ await cur.execute(Xid._get_recover_query())
+ res = await cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ await self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py
new file mode 100644
index 0000000..3b21f83
--- /dev/null
+++ b/psycopg/psycopg/conninfo.py
@@ -0,0 +1,378 @@
+"""
+Functions to manipulate conninfo strings
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import socket
+import asyncio
+from typing import Any, Dict, List, Optional
+from pathlib import Path
+from datetime import tzinfo
+from functools import lru_cache
+from ipaddress import ip_address
+
+from . import pq
+from . import errors as e
+from ._tz import get_tzinfo
+from ._encodings import pgconn_encoding
+
+
+def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+ """
+ Merge a string and keyword params into a single conninfo string.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: A connection string valid for PostgreSQL, with the `!kwargs`
+ parameters merged.
+
+ Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
+ conninfo string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ if not conninfo and not kwargs:
+ return ""
+
+ # If no kwarg specified don't mung the conninfo but check if it's correct.
+ # Make sure to return a string, not a subtype, to avoid making Liskov sad.
+ if not kwargs:
+ _parse_conninfo(conninfo)
+ return str(conninfo)
+
+ # Override the conninfo with the parameters
+ # Drop the None arguments
+ kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
+
+ if conninfo:
+ tmp = conninfo_to_dict(conninfo)
+ tmp.update(kwargs)
+ kwargs = tmp
+
+ conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
+
+ # Verify the result is valid
+ _parse_conninfo(conninfo)
+
+ return conninfo
+
+
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+ """
+ Convert the `!conninfo` string into a dictionary of parameters.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: Dictionary with the parameters parsed from `!conninfo` and
+ `!kwargs`.
+
+ Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection
+ string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ opts = _parse_conninfo(conninfo)
+ rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+ for k, v in kwargs.items():
+ if v is not None:
+ rv[k] = v
+ return rv
+
+
+def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+ """
+ Verify that `!conninfo` is a valid connection string.
+
+ Raise ProgrammingError if the string is not valid.
+
+ Return the result of pq.Conninfo.parse() on success.
+ """
+ try:
+ return pq.Conninfo.parse(conninfo.encode())
+ except e.OperationalError as ex:
+ raise e.ProgrammingError(str(ex))
+
+
+re_escape = re.compile(r"([\\'])")
+re_space = re.compile(r"\s")
+
+
+def _param_escape(s: str) -> str:
+ """
+ Apply the escaping rule required by PQconnectdb
+ """
+ if not s:
+ return "''"
+
+ s = re_escape.sub(r"\\\1", s)
+ if re_space.search(s):
+ s = "'" + s + "'"
+
+ return s
+
+
+class ConnectionInfo:
+ """Allow access to information about the connection."""
+
+ __module__ = "psycopg"
+
+ def __init__(self, pgconn: pq.abc.PGconn):
+ self.pgconn = pgconn
+
+ @property
+ def vendor(self) -> str:
+ """A string representing the database vendor connected to."""
+ return "PostgreSQL"
+
+ @property
+ def host(self) -> str:
+ """The server host name of the active connection. See :pq:`PQhost()`."""
+ return self._get_pgconn_attr("host")
+
+ @property
+ def hostaddr(self) -> str:
+ """The server IP address of the connection. See :pq:`PQhostaddr()`."""
+ return self._get_pgconn_attr("hostaddr")
+
+ @property
+ def port(self) -> int:
+ """The port of the active connection. See :pq:`PQport()`."""
+ return int(self._get_pgconn_attr("port"))
+
+ @property
+ def dbname(self) -> str:
+ """The database name of the connection. See :pq:`PQdb()`."""
+ return self._get_pgconn_attr("db")
+
+ @property
+ def user(self) -> str:
+ """The user name of the connection. See :pq:`PQuser()`."""
+ return self._get_pgconn_attr("user")
+
+ @property
+ def password(self) -> str:
+ """The password of the connection. See :pq:`PQpass()`."""
+ return self._get_pgconn_attr("password")
+
+ @property
+ def options(self) -> str:
+ """
+ The command-line options passed in the connection request.
+ See :pq:`PQoptions`.
+ """
+ return self._get_pgconn_attr("options")
+
+ def get_parameters(self) -> Dict[str, str]:
+ """Return the connection parameters values.
+
+ Return all the parameters set to a non-default value, which might come
+ either from the connection string and parameters passed to
+ `~Connection.connect()` or from environment variables. The password
+ is never returned (you can read it using the `password` attribute).
+ """
+ pyenc = self.encoding
+
+ # Get the known defaults to avoid reporting them
+ defaults = {
+ i.keyword: i.compiled
+ for i in pq.Conninfo.get_defaults()
+ if i.compiled is not None
+ }
+ # Not returned by the libq. Bug? Bet we're using SSH.
+ defaults.setdefault(b"channel_binding", b"prefer")
+ defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
+
+ return {
+ i.keyword.decode(pyenc): i.val.decode(pyenc)
+ for i in self.pgconn.info
+ if i.val is not None
+ and i.keyword != b"password"
+ and i.val != defaults.get(i.keyword)
+ }
+
+ @property
+ def dsn(self) -> str:
+ """Return the connection string to connect to the database.
+
+ The string contains all the parameters set to a non-default value,
+ which might come either from the connection string and parameters
+ passed to `~Connection.connect()` or from environment variables. The
+ password is never returned (you can read it using the `password`
+ attribute).
+ """
+ return make_conninfo(**self.get_parameters())
+
+ @property
+ def status(self) -> pq.ConnStatus:
+ """The status of the connection. See :pq:`PQstatus()`."""
+ return pq.ConnStatus(self.pgconn.status)
+
+ @property
+ def transaction_status(self) -> pq.TransactionStatus:
+ """
+ The current in-transaction status of the session.
+ See :pq:`PQtransactionStatus()`.
+ """
+ return pq.TransactionStatus(self.pgconn.transaction_status)
+
+ @property
+ def pipeline_status(self) -> pq.PipelineStatus:
+ """
+ The current pipeline status of the client.
+ See :pq:`PQpipelineStatus()`.
+ """
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ def parameter_status(self, param_name: str) -> Optional[str]:
+ """
+ Return a parameter setting of the connection.
+
+ Return `None` is the parameter is unknown.
+ """
+ res = self.pgconn.parameter_status(param_name.encode(self.encoding))
+ return res.decode(self.encoding) if res is not None else None
+
+ @property
+ def server_version(self) -> int:
+ """
+ An integer representing the server version. See :pq:`PQserverVersion()`.
+ """
+ return self.pgconn.server_version
+
+ @property
+ def backend_pid(self) -> int:
+ """
+ The process ID (PID) of the backend process handling this connection.
+ See :pq:`PQbackendPID()`.
+ """
+ return self.pgconn.backend_pid
+
+ @property
+ def error_message(self) -> str:
+ """
+ The error message most recently generated by an operation on the connection.
+ See :pq:`PQerrorMessage()`.
+ """
+ return self._get_pgconn_attr("error_message")
+
+ @property
+ def timezone(self) -> tzinfo:
+ """The Python timezone info of the connection's timezone."""
+ return get_tzinfo(self.pgconn)
+
+ @property
+ def encoding(self) -> str:
+ """The Python codec name of the connection's client encoding."""
+ return pgconn_encoding(self.pgconn)
+
+ def _get_pgconn_attr(self, name: str) -> str:
+ value: bytes = getattr(self.pgconn, name)
+ return value.decode(self.encoding)
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostname``, ``port`` to allow
+ connecting without further DNS lookups, eventually removing hosts that are
+ not resolved, keeping the lists of hosts and ports consistent.
+
+ Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+ host resolve, inconsistent lists length).
+ """
+ hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+ if hostaddr_arg:
+ # Already resolved
+ return params
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ if not host_arg:
+ # Nothing to resolve
+ return params
+
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",") if port_arg else []
+ default_port = "5432"
+
+ if len(ports_in) == 1:
+ # If only one port is specified, the libpq will apply it to all
+ # the hosts, so don't mangle it.
+ default_port = ports_in.pop()
+
+ elif len(ports_in) > 1:
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+ ports_out = []
+
+ hosts_out = []
+ hostaddr_out = []
+ loop = asyncio.get_running_loop()
+ for i, host in enumerate(hosts_in):
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path
+ hosts_out.append(host)
+ hostaddr_out.append("")
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ # If the host is already an ip address don't try to resolve it
+ if is_ip_address(host):
+ hosts_out.append(host)
+ hostaddr_out.append(host)
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ try:
+ port = ports_in[i] if ports_in else default_port
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ except OSError as ex:
+ last_exc = ex
+ else:
+ for item in ans:
+ hosts_out.append(host)
+ hostaddr_out.append(item[4][0])
+ if ports_in:
+ ports_out.append(ports_in[i])
+
+ # Throw an exception if no host could be resolved
+ if not hosts_out:
+ raise e.OperationalError(str(last_exc))
+
+ out = params.copy()
+ out["host"] = ",".join(hosts_out)
+ out["hostaddr"] = ",".join(hostaddr_out)
+ if ports_in:
+ out["port"] = ",".join(ports_out)
+
+ return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py
new file mode 100644
index 0000000..7514306
--- /dev/null
+++ b/psycopg/psycopg/copy.py
@@ -0,0 +1,904 @@
+"""
+psycopg copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import queue
+import struct
+import asyncio
+import threading
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
+from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import Buffer, ConnectionType, PQGen, Transformer
+from ._compat import create_task
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding
+from .generators import copy_from, copy_to, copy_end
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255) so we want to split it in more digestable chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! Probably triggering some quadraticity
+# in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. More than that copy will block
+# Each buffer should be around BUFFER_SIZE size.
+QUEUE_SIZE = 1024
+
+
+class BaseCopy(Generic[ConnectionType]):
+ """
+ Base implementation for the copy user interface.
+
+ Two subclasses expose real methods with the sync/async differences.
+
+ The difference between the text and binary format is managed by two
+ different `Formatter` subclasses.
+
+ Writing (the I/O part) is implemented in the subclasses by a `Writer` or
+ `AsyncWriter` instance. Normally writing implies sending copy data to a
+ database, but a different writer might be chosen, e.g. to stream data into
+ a file for later use.
+ """
+
+ _Self = TypeVar("_Self", bound="BaseCopy[Any]")
+
+ formatter: "Formatter"
+
+ def __init__(
+ self,
+ cursor: "BaseCursor[ConnectionType, Any]",
+ *,
+ binary: Optional[bool] = None,
+ ):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ result = cursor.pgresult
+ if result:
+ self._direction = result.status
+ if self._direction != COPY_IN and self._direction != COPY_OUT:
+ raise e.ProgrammingError(
+ "the cursor should have performed a COPY operation;"
+ f" its status is {pq.ExecStatus(self._direction).name} instead"
+ )
+ else:
+ self._direction = COPY_IN
+
+ if binary is None:
+ binary = bool(result and result.binary_tuples)
+
+ tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+ if binary:
+ self.formatter = BinaryFormatter(tx)
+ else:
+ self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+
+ self._finished = False
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def _enter(self) -> None:
+ if self._finished:
+ raise TypeError("copy blocks can be used only once")
+
+ def set_types(self, types: Sequence[Union[int, str]]) -> None:
+ """
+ Set the types expected in a COPY operation.
+
+ The types must be specified as a sequence of oid or PostgreSQL type
+ names (e.g. ``int4``, ``timestamptz[]``).
+
+ This operation overcomes the lack of metadata returned by PostgreSQL
+ when a COPY operation begins:
+
+ - On :sql:`COPY TO`, `!set_types()` allows to specify what types the
+ operation returns. If `!set_types()` is not used, the data will be
+ returned as unparsed strings or bytes instead of Python objects.
+
+ - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
+ database expects. This is especially useful in binary copy, because
+ PostgreSQL will apply no cast rule.
+
+ """
+ registry = self.cursor.adapters.types
+ oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
+
+ if self._direction == COPY_IN:
+ self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
+ else:
+ self.formatter.transformer.set_loader_types(oids, self.formatter.format)
+
+ # High level copy protocol generators (state change of the Copy object)
+
+ def _read_gen(self) -> PQGen[Buffer]:
+ if self._finished:
+ return memoryview(b"")
+
+ res = yield from copy_from(self._pgconn)
+ if isinstance(res, memoryview):
+ return res
+
+ # res is the final PGresult
+ self._finished = True
+
+ # This result is a COMMAND_OK which has info about the number of rows
+ # returned, but not about the columns, which is instead an information
+ # that was received on the COPY_OUT result at the beginning of COPY.
+ # So, don't replace the results in the cursor, just update the rowcount.
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
+ return memoryview(b"")
+
+ def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+ data = yield from self._read_gen()
+ if not data:
+ return None
+
+ row = self.formatter.parse_row(data)
+ if row is None:
+ # Get the final result to finish the copy operation
+ yield from self._read_gen()
+ self._finished = True
+ return None
+
+ return row
+
+ def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ if not exc:
+ return
+
+ if self._pgconn.transaction_status != ACTIVE:
+ # The server has already finished to send copy data. The connection
+ # is already in a good state.
+ return
+
+ # Throw a cancel to the server, then consume the rest of the copy data
+ # (which might or might not have been already transferred entirely to
+ # the client, so we won't necessary see the exception associated with
+ # canceling).
+ self.connection.cancel()
+ try:
+ while (yield from self._read_gen()):
+ pass
+ except e.QueryCanceled:
+ pass
+
+
+class Copy(BaseCopy["Connection[Any]"]):
+ """Manage a :sql:`COPY` operation.
+
+ :param cursor: the cursor where the operation is performed.
+ :param binary: if `!True`, write binary format.
+ :param writer: the object to write to destination. If not specified, write
+ to the `!cursor` connection.
+
+ Choosing `!binary` is not necessary if the cursor has executed a
+ :sql:`COPY` operation, because the operation result describes the format
+ too. The parameter is useful when a `!Copy` object is created manually and
+ no operation is performed on the cursor, such as when using ``writer=``\\
+ `~psycopg.copy.FileWriter`.
+
+ """
+
+ __module__ = "psycopg"
+
+ writer: "Writer"
+
+ def __init__(
+ self,
+ cursor: "Cursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["Writer"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+ if not writer:
+ writer = LibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.finish(exc_val)
+
+ # End user sync interface
+
+ def __iter__(self) -> Iterator[Buffer]:
+ """Implement block-by-block iteration on :sql:`COPY TO`."""
+ while True:
+ data = self.read()
+ if not data:
+ break
+ yield data
+
+ def read(self) -> Buffer:
+ """
+ Read an unparsed row after a :sql:`COPY TO` operation.
+
+ Return an empty string when the data is finished.
+ """
+ return self.connection.wait(self._read_gen())
+
+ def rows(self) -> Iterator[Tuple[Any, ...]]:
+ """
+ Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ while True:
+ record = self.read_row()
+ if record is None:
+ break
+ yield record
+
+ def read_row(self) -> Optional[Tuple[Any, ...]]:
+ """
+ Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+ Return `!None` when the data is finished.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ return self.connection.wait(self._read_row_gen())
+
+ def write(self, buffer: Union[Buffer, str]) -> None:
+ """
+ Write a block of data to a table after a :sql:`COPY FROM` operation.
+
+ If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
+ text mode it can be either `!bytes` or `!str`.
+ """
+ data = self.formatter.write(buffer)
+ if data:
+ self._write(data)
+
+ def write_row(self, row: Sequence[Any]) -> None:
+ """Write a record to a table after a :sql:`COPY FROM` operation."""
+ data = self.formatter.write_row(row)
+ if data:
+ self._write(data)
+
+ def finish(self, exc: Optional[BaseException]) -> None:
+ """Terminate the copy operation and free the resources allocated.
+
+ You shouldn't need to call this function yourself: it is usually called
+ by exit. It is available if, despite what is documented, you end up
+ using the `Copy` object outside a block.
+ """
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ self._write(data)
+ self.writer.finish(exc)
+ self._finished = True
+ else:
+ self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class Writer(ABC):
+ """
+ A class to write copy data somewhere.
+ """
+
+ @abstractmethod
+ def write(self, data: Buffer) -> None:
+ """
+ Write some data to destination.
+ """
+ ...
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ """
+ Called when write operations are finished.
+
+ If operations finished with an error, it will be passed to ``exc``.
+ """
+ pass
+
+
+class LibpqWriter(Writer):
+ """
+ A `Writer` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class QueuedLibpqDriver(LibpqWriter):
+ """
+ A writer using a buffer to queue data to write to a Postgres database.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[threading.Thread] = None
+ self._worker_error: Optional[BaseException] = None
+
+ def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value, or in case
+ of error.
+
+ The function is designed to be run in a separate thread.
+ """
+ try:
+ while True:
+ data = self._queue.get(block=True, timeout=24 * 60 * 60)
+ if not data:
+ break
+ self.connection.wait(copy_to(self._pgconn, data))
+ except BaseException as ex:
+ # Propagate the error to the main thread.
+ self._worker_error = ex
+
+ def write(self, data: Buffer) -> None:
+ if not self._worker:
+ # warning: reference loop, broken by _write_end
+ self._worker = threading.Thread(target=self.worker)
+ self._worker.daemon = True
+ self._worker.start()
+
+ # If the worker thread raies an exception, re-raise it to the caller.
+ if self._worker_error:
+ raise self._worker_error
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ self._queue.put(b"")
+
+ if self._worker:
+ self._worker.join()
+ self._worker = None # break the loop
+
+ # Check if the worker thread raised any exception before terminating.
+ if self._worker_error:
+ raise self._worker_error
+
+ super().finish(exc)
+
+
+class FileWriter(Writer):
+ """
+ A `Writer` to write copy data to a file-like object.
+
+ :param file: the file where to write copy data. It must be open for writing
+ in binary mode.
+ """
+
+ def __init__(self, file: IO[bytes]):
+ self.file = file
+
+ def write(self, data: Buffer) -> None:
+ self.file.write(data) # type: ignore[arg-type]
+
+
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
+ """Manage an asynchronous :sql:`COPY` operation."""
+
+ __module__ = "psycopg"
+
+ writer: "AsyncWriter"
+
+ def __init__(
+ self,
+ cursor: "AsyncCursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["AsyncWriter"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+
+ if not writer:
+ writer = AsyncLibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.finish(exc_val)
+
+ async def __aiter__(self) -> AsyncIterator[Buffer]:
+ while True:
+ data = await self.read()
+ if not data:
+ break
+ yield data
+
+ async def read(self) -> Buffer:
+ return await self.connection.wait(self._read_gen())
+
+ async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
+ while True:
+ record = await self.read_row()
+ if record is None:
+ break
+ yield record
+
+ async def read_row(self) -> Optional[Tuple[Any, ...]]:
+ return await self.connection.wait(self._read_row_gen())
+
+ async def write(self, buffer: Union[Buffer, str]) -> None:
+ data = self.formatter.write(buffer)
+ if data:
+ await self._write(data)
+
+ async def write_row(self, row: Sequence[Any]) -> None:
+ data = self.formatter.write_row(row)
+ if data:
+ await self._write(data)
+
+ async def finish(self, exc: Optional[BaseException]) -> None:
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ await self._write(data)
+ await self.writer.finish(exc)
+ self._finished = True
+ else:
+ await self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class AsyncWriter(ABC):
+ """
+ A class to write copy data somewhere (for async connections).
+ """
+
+ @abstractmethod
+ async def write(self, data: Buffer) -> None:
+ ...
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ pass
+
+
+class AsyncLibpqWriter(AsyncWriter):
+ """
+ An `AsyncWriter` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ async def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = await self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
+ """
+ An `AsyncWriter` using a buffer to queue data to write.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[asyncio.Future[None]] = None
+
+ async def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value.
+
+ The function is designed to be run in a separate task.
+ """
+ while True:
+ data = await self._queue.get()
+ if not data:
+ break
+ await self.connection.wait(copy_to(self._pgconn, data))
+
+ async def write(self, data: Buffer) -> None:
+ if not self._worker:
+ self._worker = create_task(self.worker())
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ await self._queue.put(b"")
+
+ if self._worker:
+ await asyncio.gather(self._worker)
+ self._worker = None # break reference loops if any
+
+ await super().finish(exc)
+
+
+class Formatter(ABC):
+ """
+ A class which understand a copy format (text, binary).
+ """
+
+ format: pq.Format
+
+ def __init__(self, transformer: Transformer):
+ self.transformer = transformer
+ self._write_buffer = bytearray()
+ self._row_mode = False # true if the user is using write_row()
+
+ @abstractmethod
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ ...
+
+ @abstractmethod
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def end(self) -> Buffer:
+ ...
+
+
+class TextFormatter(Formatter):
+
+ format = TEXT
+
+ def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+ super().__init__(transformer)
+ self._encoding = encoding
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if data:
+ return parse_row_text(data, self.transformer)
+ else:
+ return None
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ format_row_text(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ return data.encode(self._encoding)
+ else:
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
+
+
+class BinaryFormatter(Formatter):
+
+ format = BINARY
+
+ def __init__(self, transformer: Transformer):
+ super().__init__(transformer)
+ self._signature_sent = False
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if not self._signature_sent:
+ if data[: len(_binary_signature)] != _binary_signature:
+ raise e.DataError(
+ "binary copy doesn't start with the expected signature"
+ )
+ self._signature_sent = True
+ data = data[len(_binary_signature) :]
+
+ elif data == _binary_trailer:
+ return None
+
+ return parse_row_binary(data, self.transformer)
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._signature_sent = True
+
+ format_row_binary(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._write_buffer += _binary_trailer
+
+ elif self._row_mode:
+ # if we have sent data already, we have sent the signature
+ # too (either with the first row, or we assume that in
+ # block mode the signature is included).
+ # Write the trailer only if we are sending rows (with the
+ # assumption that who is copying binary data is sending the
+ # whole format).
+ self._write_buffer += _binary_trailer
+
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ raise TypeError("cannot copy str data in binary mode: use bytes instead")
+ else:
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
+
+
+def _format_row_text(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for copy."""
+ if out is None:
+ out = bytearray()
+
+ if not row:
+ out += b"\n"
+ return out
+
+ for item in row:
+ if item is not None:
+ dumper = tx.get_dumper(item, PY_TEXT)
+ b = dumper.dump(item)
+ out += _dump_re.sub(_dump_sub, b)
+ else:
+ out += rb"\N"
+ out += b"\t"
+
+ out[-1:] = b"\n"
+ return out
+
+
+def _format_row_binary(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for binary copy."""
+ if out is None:
+ out = bytearray()
+
+ out += _pack_int2(len(row))
+ adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
+ for b in adapted:
+ if b is not None:
+ out += _pack_int4(len(b))
+ out += b
+ else:
+ out += _binary_null
+
+ return out
+
+
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ fields = data.split(b"\t")
+ fields[-1] = fields[-1][:-1] # drop \n
+ row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+ return tx.load_sequence(row)
+
+
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ row: List[Optional[Buffer]] = []
+ nfields = _unpack_int2(data, 0)[0]
+ pos = 2
+ for i in range(nfields):
+ length = _unpack_int4(data, pos)[0]
+ pos += 4
+ if length >= 0:
+ row.append(data[pos : pos + length])
+ pos += length
+ else:
+ row.append(None)
+
+ return tx.load_sequence(row)
+
+
+_pack_int2 = struct.Struct("!h").pack
+_pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
+
+_binary_signature = (
+ b"PGCOPY\n\xff\r\n\0" # Signature
+ b"\x00\x00\x00\x00" # flags
+ b"\x00\x00\x00\x00" # extra length
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+ b"\b": b"\\b",
+ b"\t": b"\\t",
+ b"\n": b"\\n",
+ b"\v": b"\\v",
+ b"\f": b"\\f",
+ b"\r": b"\\r",
+ b"\\": b"\\\\",
+}
+
+
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
+ return __map[m.group(0)]
+
+
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
+ return __map[m.group(0)]
+
+
+# Override functions with fast versions if available
+if _psycopg:
+ format_row_text = _psycopg.format_row_text
+ format_row_binary = _psycopg.format_row_binary
+ parse_row_text = _psycopg.parse_row_text
+ parse_row_binary = _psycopg.parse_row_binary
+
+else:
+ format_row_text = _format_row_text
+ format_row_binary = _format_row_binary
+ parse_row_text = _parse_row_text
+ parse_row_binary = _parse_row_binary
diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py
new file mode 100644
index 0000000..323903a
--- /dev/null
+++ b/psycopg/psycopg/crdb/__init__.py
@@ -0,0 +1,19 @@
+"""
+CockroachDB support package.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from . import _types
+from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
+
+adapters = _types.adapters # exposed by the package
+connect = CrdbConnection.connect
+
+_types.register_crdb_adapters(adapters)
+
+__all__ = [
+ "AsyncCrdbConnection",
+ "CrdbConnection",
+ "CrdbConnectionInfo",
+]
diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py
new file mode 100644
index 0000000..5311e05
--- /dev/null
+++ b/psycopg/psycopg/crdb/_types.py
@@ -0,0 +1,163 @@
+"""
+Types configuration specific for CockroachDB.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from enum import Enum
+from .._typeinfo import TypeInfo, TypesRegistry
+
+from ..abc import AdaptContext, NoneType
+from ..postgres import TEXT_OID
+from .._adapters_map import AdaptersMap
+from ..types.enum import EnumDumper, EnumBinaryDumper
+from ..types.none import NoneDumper
+
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+
+class CrdbEnumDumper(EnumDumper):
+ oid = TEXT_OID
+
+
+class CrdbEnumBinaryDumper(EnumBinaryDumper):
+ oid = TEXT_OID
+
+
+class CrdbNoneDumper(NoneDumper):
+ oid = TEXT_OID
+
+
+def register_postgres_adapters(context: AdaptContext) -> None:
+ # Same adapters used by PostgreSQL, or a good starting point for customization
+
+ from ..types import array, bool, composite, datetime
+ from ..types import numeric, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
+
+
+def register_crdb_adapters(context: AdaptContext) -> None:
+ from .. import dbapi20
+ from ..types import array
+
+ register_postgres_adapters(context)
+
+ # String must come after enum to map text oid -> string dumper
+ register_crdb_enum_adapters(context)
+ register_crdb_string_adapters(context)
+ register_crdb_json_adapters(context)
+ register_crdb_net_adapters(context)
+ register_crdb_none_adapters(context)
+
+ dbapi20.register_dbapi20_adapters(adapters)
+
+ array.register_all_arrays(adapters)
+
+
+def register_crdb_string_adapters(context: AdaptContext) -> None:
+ from ..types import string
+
+ # Dump strings with text oid instead of unknown.
+ # Unlike PostgreSQL, CRDB seems able to cast text to most types.
+ context.adapters.register_dumper(str, string.StrDumper)
+ context.adapters.register_dumper(str, string.StrBinaryDumper)
+
+
+def register_crdb_enum_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
+ context.adapters.register_dumper(Enum, CrdbEnumDumper)
+
+
+def register_crdb_json_adapters(context: AdaptContext) -> None:
+ from ..types import json
+
+ adapters = context.adapters
+
+ # CRDB doesn't have json/jsonb: both names map to the jsonb oid
+ adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Json, json.JsonbDumper)
+
+ adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Jsonb, json.JsonbDumper)
+
+ adapters.register_loader("json", json.JsonLoader)
+ adapters.register_loader("jsonb", json.JsonbLoader)
+ adapters.register_loader("json", json.JsonBinaryLoader)
+ adapters.register_loader("jsonb", json.JsonbBinaryLoader)
+
+
+def register_crdb_net_adapters(context: AdaptContext) -> None:
+ from ..types import net
+
+ adapters = context.adapters
+
+ adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper(None, net.InetBinaryDumper)
+ adapters.register_loader("inet", net.InetLoader)
+ adapters.register_loader("inet", net.InetBinaryLoader)
+
+
+def register_crdb_none_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, CrdbNoneDumper)
+
+
+for t in [
+ TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
+ TypeInfo('"char"', 18, 1002), # special case, not generated
+ # autogenerated: start
+ # Generated from CockroachDB 22.1.0
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("unknown", 705, 0),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ # autogenerated: end
+]:
+ types.add(t)
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
new file mode 100644
index 0000000..6e79ed1
--- /dev/null
+++ b/psycopg/psycopg/crdb/connection.py
@@ -0,0 +1,186 @@
+"""
+CockroachDB-specific connections.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import re
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
+
+from .. import errors as e
+from ..abc import AdaptContext
+from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
+from ..conninfo import ConnectionInfo
+from ..connection import Connection
+from .._adapters_map import AdaptersMap
+from ..connection_async import AsyncConnection
+from ._types import adapters
+
+if TYPE_CHECKING:
+ from ..pq.abc import PGconn
+ from ..cursor import Cursor
+ from ..cursor_async import AsyncCursor
+
+
+class _CrdbConnectionMixin:
+
+ _adapters: Optional[AdaptersMap]
+ pgconn: "PGconn"
+
+ @classmethod
+ def is_crdb(
+ cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
+ ) -> bool:
+ """
+ Return `!True` if the server connected to `!conn` is CockroachDB.
+ """
+ if isinstance(conn, (Connection, AsyncConnection)):
+ conn = conn.pgconn
+
+ return bool(conn.parameter_status(b"crdb_version"))
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ # By default, use CockroachDB adapters map
+ self._adapters = AdaptersMap(adapters)
+
+ return self._adapters
+
+ @property
+ def info(self) -> "CrdbConnectionInfo":
+ return CrdbConnectionInfo(self.pgconn)
+
+ def _check_tpc(self) -> None:
+ if self.is_crdb(self.pgconn):
+ raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
+
+
+class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
+ """
+ Wrapper for a connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+ """
+ Connect to a database server and return a new `CrdbConnection` instance.
+ """
+ return super().connect(conninfo, **kwargs) # type: ignore[return-value]
+
+
+class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
+ """
+ Wrapper for an async connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ async def connect(
+ cls, conninfo: str = "", **kwargs: Any
+ ) -> "AsyncCrdbConnection[Any]":
+ return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
+
+
+class CrdbConnectionInfo(ConnectionInfo):
+ """
+ `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ @property
+ def vendor(self) -> str:
+ return "CockroachDB"
+
+ @property
+ def server_version(self) -> int:
+ """
+ Return the CockroachDB server version connected.
+
+ Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
+ """
+ sver = self.parameter_status("crdb_version")
+ if not sver:
+ raise e.InternalError("'crdb_version' parameter status not set")
+
+ ver = self.parse_crdb_version(sver)
+ if ver is None:
+ raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
+
+ return ver
+
+ @classmethod
+ def parse_crdb_version(self, sver: str) -> Optional[int]:
+ m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+ if not m:
+ return None
+
+ return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
new file mode 100644
index 0000000..42c3804
--- /dev/null
+++ b/psycopg/psycopg/cursor.py
@@ -0,0 +1,921 @@
+"""
+psycopg cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from functools import partial
+from types import TracebackType
+from typing import Any, Generic, Iterable, Iterator, List
+from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
+from typing import overload, TYPE_CHECKING
+from contextlib import contextmanager
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .copy import Copy, Writer as CopyWriter
+from .rows import Row, RowMaker, RowFactory
+from ._column import Column
+from ._queries import PostgresQuery, PostgresClientQuery
+from ._pipeline import Pipeline
+from ._encodings import pgconn_encoding
+from ._preparing import Prepare
+from .generators import execute, fetch, send
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+ from .pq.abc import PGconn, PGresult
+ from .connection import Connection
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class BaseCursor(Generic[ConnectionType, Row]):
+ __slots__ = """
+ _conn format _adapters arraysize _closed _results pgresult _pos
+ _iresult _rowcount _query _tx _last_query _row_factory _make_row
+ _pgconn _execmany_returning
+ __weakref__
+ """.split()
+
+ ExecStatus = pq.ExecStatus
+
+ _tx: "Transformer"
+ _make_row: RowMaker[Row]
+ _pgconn: "PGconn"
+
+ def __init__(self, connection: ConnectionType):
+ self._conn = connection
+ self.format = TEXT
+ self._pgconn = connection.pgconn
+ self._adapters = adapt.AdaptersMap(connection.adapters)
+ self.arraysize = 1
+ self._closed = False
+ self._last_query: Optional[Query] = None
+ self._reset()
+
+ def _reset(self, reset_query: bool = True) -> None:
+ self._results: List["PGresult"] = []
+ self.pgresult: Optional["PGresult"] = None
+ self._pos = 0
+ self._iresult = 0
+ self._rowcount = -1
+ self._query: Optional[PostgresQuery]
+ # None if executemany() not executing, True/False according to returning state
+ self._execmany_returning: Optional[bool] = None
+ if reset_query:
+ self._query = None
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ if self._closed:
+ status = "closed"
+ elif self.pgresult:
+ status = pq.ExecStatus(self.pgresult.status).name
+ else:
+ status = "no result"
+ return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
+
+ @property
+ def connection(self) -> ConnectionType:
+ """The connection this cursor is using."""
+ return self._conn
+
+ @property
+ def adapters(self) -> adapt.AdaptersMap:
+ return self._adapters
+
+ @property
+ def closed(self) -> bool:
+ """`True` if the cursor is closed."""
+ return self._closed
+
+ @property
+ def description(self) -> Optional[List[Column]]:
+ """
+ A list of `Column` objects describing the current resultset.
+
+ `!None` if the current resultset didn't return tuples.
+ """
+ res = self.pgresult
+
+ # We return columns if we have nfields, but also if we don't but
+ # the query said we got tuples (mostly to handle the super useful
+ # query "SELECT ;"
+ if res and (
+ res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
+ ):
+ return [Column(self, i) for i in range(res.nfields)]
+ else:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ """Number of records affected by the precedent operation."""
+ return self._rowcount
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+ return self._pos if tuples else None
+
+ def setinputsizes(self, sizes: Sequence[Any]) -> None:
+ # no-op
+ pass
+
+ def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
+ # no-op
+ pass
+
+ def nextset(self) -> Optional[bool]:
+ """
+ Move to the result set of the next query executed through `executemany()`
+ or to the next result set if `execute()` returned more than one.
+
+ Return `!True` if a new result is available, which will be the one
+ methods `!fetch*()` will operate on.
+ """
+ if self._iresult < len(self._results) - 1:
+ self._select_current_result(self._iresult + 1)
+ return True
+ else:
+ return None
+
+ @property
+ def statusmessage(self) -> Optional[str]:
+ """
+ The command status tag from the last SQL command executed.
+
+ `!None` if the cursor doesn't have a result available.
+ """
+ msg = self.pgresult.command_status if self.pgresult else None
+ return msg.decode() if msg else None
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ raise NotImplementedError
+
+ #
+ # Generators for the high level operations on the cursor
+ #
+ # Like for sync/async connections, these are implemented as generators
+ # so that different concurrency strategies (threads,asyncio) can use their
+ # own way of waiting (or better, `connection.wait()`).
+ #
+
+ def _execute_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `Cursor.execute()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ results = yield from self._maybe_prepare_gen(
+ pgq, prepare=prepare, binary=binary
+ )
+ if self._conn._pipeline:
+ yield from self._conn._pipeline._communicate_gen()
+ else:
+ assert results is not None
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0)
+
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines available.
+ """
+ pipeline = self._conn._pipeline
+ assert pipeline
+
+ yield from self._start_query(query)
+ self._rowcount = 0
+
+ assert self._execmany_returning is None
+ self._execmany_returning = returning
+
+ first = True
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ yield from self._maybe_prepare_gen(pgq, prepare=True)
+ yield from pipeline._communicate_gen()
+
+ self._last_query = query
+
+ if returning:
+ yield from pipeline._fetch_gen(flush=True)
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_no_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines not available.
+ """
+ yield from self._start_query(query)
+ first = True
+ nrows = 0
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ results = yield from self._maybe_prepare_gen(pgq, prepare=True)
+ assert results is not None
+ self._check_results(results)
+ if returning:
+ self._results.extend(results)
+
+ for res in results:
+ nrows += res.command_tuples or 0
+
+ if self._results:
+ self._select_current_result(0)
+
+ # Override rowcount for the first result. Calls to nextset() will change
+ # it to the value of that result only, but we hope nobody will notice.
+ # You haven't read this comment.
+ self._rowcount = nrows
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _maybe_prepare_gen(
+ self,
+ pgq: PostgresQuery,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[Optional[List["PGresult"]]]:
+ # Check if the query is prepared or needs preparing
+ prep, name = self._get_prepared(pgq, prepare)
+ if prep is Prepare.NO:
+ # The query must be executed without preparing
+ self._execute_send(pgq, binary=binary)
+ else:
+ # If the query is not already prepared, prepare it.
+ if prep is Prepare.SHOULD:
+ self._send_prepare(name, pgq)
+ if not self._conn._pipeline:
+ (result,) = yield from execute(self._pgconn)
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ # Then execute it.
+ self._send_query_prepared(name, pgq, binary=binary)
+
+ # Update the prepare state of the query.
+ # If an operation requires to flush our prepared statements cache,
+ # it will be added to the maintenance commands to execute later.
+ key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+ if self._conn._pipeline:
+ queued = None
+ if key is not None:
+ queued = (key, prep, name)
+ self._conn._pipeline.result_queue.append((self, queued))
+ return None
+
+ # run the query
+ results = yield from execute(self._pgconn)
+
+ if key is not None:
+ self._conn._prepared.validate(key, prep, name, results)
+
+ return results
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return self._conn._prepared.get(pgq, prepare)
+
+ def _stream_send_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator to send the query for `Cursor.stream()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, binary=binary, force_extended=True)
+ self._pgconn.set_single_row_mode()
+ self._last_query = query
+ yield from send(self._pgconn)
+
+ def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
+ res = yield from fetch(self._pgconn)
+ if res is None:
+ return None
+
+ status = res.status
+ if status == SINGLE_TUPLE:
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=first)
+ if first:
+ self._make_row = self._make_row_maker()
+ return res
+
+ elif status == TUPLES_OK or status == COMMAND_OK:
+ # End of single row results
+ while res:
+ res = yield from fetch(self._pgconn)
+ if status != TUPLES_OK:
+ raise e.ProgrammingError(
+ "the operation in stream() didn't produce a result"
+ )
+ return None
+
+ else:
+ # Errors, unexpected values
+ return self._raise_for_result(res)
+
+ def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
+ """Generator to start the processing of a query.
+
+ It is implemented as generator because it may send additional queries,
+ such as `begin`.
+ """
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+
+ self._reset()
+ if not self._last_query or (self._last_query is not query):
+ self._last_query = None
+ self._tx = adapt.Transformer(self)
+ yield from self._conn._start_query()
+
+ def _start_copy_gen(
+ self, statement: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
+ """Generator implementing sending a command for `Cursor.copy()."""
+
+ # The connection gets in an unrecoverable state if we attempt COPY in
+ # pipeline mode. Forbid it explicitly.
+ if self._conn._pipeline:
+ raise e.NotSupportedError("COPY cannot be used in pipeline mode")
+
+ yield from self._start_query()
+
+ # Merge the params client-side
+ if params:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(statement, params)
+ statement = pgq.query
+
+ query = self._convert_query(statement)
+
+ self._execute_send(query, binary=False)
+ results = yield from execute(self._pgconn)
+ if len(results) != 1:
+ raise e.ProgrammingError("COPY cannot be mixed with other operations")
+
+ self._check_copy_result(results[0])
+ self._results = results
+ self._select_current_result(0)
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ """
+ Implement part of execute() before waiting common to sync and async.
+
+ This is not a generator, but a normal non-blocking function.
+ """
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_params,
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ )
+ elif force_extended or query.params or fmt == BINARY:
+ self._pgconn.send_query_params(
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _check_results(self, results: List["PGresult"]) -> None:
+ """
+ Verify that the results of a query are valid.
+
+ Verify that the query returned at least one result and that they all
+ represent a valid result from the database.
+ """
+ if not results:
+ raise e.InternalError("got no result from the query")
+
+ for res in results:
+ status = res.status
+ if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
+ self._raise_for_result(res)
+
+ def _raise_for_result(self, result: "PGresult") -> NoReturn:
+ """
+ Raise an appropriate error message for an unexpected database result
+ """
+ status = result.status
+ if status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ raise e.ProgrammingError(
+ "COPY cannot be used with this method; use copy() instead"
+ )
+ else:
+ raise e.InternalError(
+ "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
+ )
+
+ def _select_current_result(
+ self, i: int, format: Optional[pq.Format] = None
+ ) -> None:
+ """
+ Select one of the results in the cursor as the active one.
+ """
+ self._iresult = i
+ res = self.pgresult = self._results[i]
+
+ # Note: the only reason to override format is to correctly set
+ # binary loaders on server-side cursors, because send_describe_portal
+ # only returns a text result.
+ self._tx.set_pgresult(res, format=format)
+
+ self._pos = 0
+
+ if res.status == TUPLES_OK:
+ self._rowcount = self.pgresult.ntuples
+
+ # COPY_OUT has never info about nrows. We need such result for the
+ # columns in order to return a `description`, but not overwrite the
+ # cursor rowcount (which was set by the Copy object).
+ elif res.status != COPY_OUT:
+ nrows = self.pgresult.command_tuples
+ self._rowcount = nrows if nrows is not None else -1
+
+ self._make_row = self._make_row_maker()
+
+ def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
+ self._check_results(results)
+ first_batch = not self._results
+
+ if self._execmany_returning is None:
+ # Received from execute()
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+
+ else:
+ # Received from executemany()
+ if self._execmany_returning:
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+ self._rowcount = 0
+
+ # Override rowcount for the first result. Calls to nextset() will
+ # change it to the value of that result only, but we hope nobody
+ # will notice.
+ # You haven't read this comment.
+ if self._rowcount < 0:
+ self._rowcount = 0
+ for res in results:
+ self._rowcount += res.command_tuples or 0
+
+ def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_prepare,
+ name,
+ query.query,
+ param_types=query.types,
+ )
+ )
+ self._conn._pipeline.result_queue.append(None)
+ else:
+ self._pgconn.send_prepare(name, query.query, param_types=query.types)
+
+ def _send_query_prepared(
+ self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_prepared,
+ name,
+ pgq.params,
+ param_formats=pgq.formats,
+ result_format=fmt,
+ )
+ )
+ else:
+ self._pgconn.send_query_prepared(
+ name, pgq.params, param_formats=pgq.formats, result_format=fmt
+ )
+
+ def _check_result_for_fetch(self) -> None:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ res = self.pgresult
+ if not res:
+ raise e.ProgrammingError("no result available")
+
+ status = res.status
+ if status == TUPLES_OK:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(res, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ raise e.ProgrammingError("the last operation didn't produce a result")
+
+ def _check_copy_result(self, result: "PGresult") -> None:
+ """
+ Check that the value returned in a copy() operation is a legit COPY.
+ """
+ status = result.status
+ if status == COPY_IN or status == COPY_OUT:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ else:
+ raise e.ProgrammingError(
+ "copy() should be used only with COPY ... TO STDOUT or COPY ..."
+ f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
+ )
+
+ def _scroll(self, value: int, mode: str) -> None:
+ self._check_result_for_fetch()
+ assert self.pgresult
+ if mode == "relative":
+ newpos = self._pos + value
+ elif mode == "absolute":
+ newpos = value
+ else:
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ if not 0 <= newpos < self.pgresult.ntuples:
+ raise IndexError("position out of bound")
+ self._pos = newpos
+
+ def _close(self) -> None:
+ """Non-blocking part of closing. Common to sync/async."""
+ # Don't reset the query because it may be useful to investigate after
+ # an error.
+ self._reset(reset_query=False)
+ self._closed = True
+
+ @property
+ def _encoding(self) -> str:
+ return pgconn_encoding(self._pgconn)
+
+
+class Cursor(BaseCursor["Connection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="Cursor[Any]")
+
+ @overload
+ def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "Cursor[Row]",
+ connection: "Connection[Any]",
+ *,
+ row_factory: RowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ self._close()
+
+ @property
+ def row_factory(self) -> RowFactory[Row]:
+ """Writable attribute to control how result rows are formed."""
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: RowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ """
+ Execute a query or command to the database.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ """
+ Execute the same command with a sequence of input data.
+ """
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ with self._conn.pipeline(), self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ else:
+ with self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_no_pipeline(query, params_seq, returning)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> Iterator[Row]:
+ """
+ Iterate row-by-row on a result from the database.
+ """
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ with self._conn.lock:
+
+ try:
+ self._conn.wait(self._stream_send_gen(query, params, binary=binary))
+ first = True
+ while self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while self._conn.wait(self._stream_fetchone_gen(first=False)):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
+ def fetchone(self) -> Optional[Row]:
+ """
+ Return the next record from the current recordset.
+
+ Return `!None` the recordset is finished.
+
+ :rtype: Optional[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ record = self._tx.load_row(self._pos, self._make_row)
+ if record is not None:
+ self._pos += 1
+ return record
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ """
+ Return the next `!size` records from the current recordset.
+
+ `!size` default to `!self.arraysize` if not specified.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ def fetchall(self) -> List[Row]:
+ """
+ Return all the remaining records from the current recordset.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ def __iter__(self) -> Iterator[Row]:
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ """
+ Move the cursor in the result set to a new position according to mode.
+
+ If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
+ the current position in the result set; if set to ``'absolute'``,
+ `!value` states an absolute target position.
+
+ Raise `!IndexError` in case a scroll operation would leave the result
+ set. In this case the position will not change.
+ """
+ self._fetch_pipeline()
+ self._scroll(value, mode)
+
+ @contextmanager
+ def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[CopyWriter] = None,
+ ) -> Iterator[Copy]:
+ """
+ Initiate a :sql:`COPY` operation and return an object to manage it.
+
+ :rtype: Copy
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._start_copy_gen(statement, params))
+
+ with Copy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ # If a fresher result has been set on the cursor by the Copy object,
+ # read its properties (especially rowcount).
+ self._select_current_result(0)
+
+ def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ with self._conn.lock:
+ self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
new file mode 100644
index 0000000..8971d40
--- /dev/null
+++ b/psycopg/psycopg/cursor_async.py
@@ -0,0 +1,250 @@
+"""
+psycopg async cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from types import TracebackType
+from typing import Any, AsyncIterator, Iterable, List
+from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from .abc import Query, Params
+from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
+from .rows import Row, RowMaker, AsyncRowFactory
+from .cursor import BaseCursor
+from ._pipeline import Pipeline
+
+if TYPE_CHECKING:
+ from .connection_async import AsyncConnection
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncCursor[Any]")
+
+ @overload
+ def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: AsyncRowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ async def close(self) -> None:
+ self._close()
+
+ @property
+ def row_factory(self) -> AsyncRowFactory[Row]:
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ async with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ async with self._conn.pipeline(), self._conn.lock:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ else:
+ await self._conn.wait(
+ self._executemany_gen_no_pipeline(query, params_seq, returning)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> AsyncIterator[Row]:
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ async with self._conn.lock:
+
+ try:
+ await self._conn.wait(
+ self._stream_send_gen(query, params, binary=binary)
+ )
+ first = True
+ while await self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while await self._conn.wait(
+ self._stream_fetchone_gen(first=False)
+ ):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ await self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
+ async def fetchone(self) -> Optional[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ rv = self._tx.load_row(self._pos, self._make_row)
+ if rv is not None:
+ self._pos += 1
+ return rv
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ async def fetchall(self) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ self._scroll(value, mode)
+
+ @asynccontextmanager
+ async def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[AsyncCopyWriter] = None,
+ ) -> AsyncIterator[AsyncCopy]:
+ """
+ :rtype: AsyncCopy
+ """
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._start_copy_gen(statement, params))
+
+ async with AsyncCopy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ self._select_current_result(0)
+
+ async def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ async with self._conn.lock:
+ await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py
new file mode 100644
index 0000000..3c3d8b7
--- /dev/null
+++ b/psycopg/psycopg/dbapi20.py
@@ -0,0 +1,112 @@
+"""
+Compatibility objects with DBAPI 2.0
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import time
+import datetime as dt
+from math import floor
+from typing import Any, Sequence, Union
+
+from . import postgres
+from .abc import AdaptContext, Buffer
+from .types.string import BytesDumper, BytesBinaryDumper
+
+
+class DBAPITypeObject:
+ def __init__(self, name: str, type_names: Sequence[str]):
+ self.name = name
+ self.values = tuple(postgres.types[n].oid for n in type_names)
+
+ def __repr__(self) -> str:
+ return f"psycopg.{self.name}"
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other in self.values
+ else:
+ return NotImplemented
+
+ def __ne__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other not in self.values
+ else:
+ return NotImplemented
+
+
+BINARY = DBAPITypeObject("BINARY", ("bytea",))
+DATETIME = DBAPITypeObject(
+ "DATETIME", "timestamp timestamptz date time timetz interval".split()
+)
+NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
+ROWID = DBAPITypeObject("ROWID", ("oid",))
+STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
+
+
+class Binary:
+ def __init__(self, obj: Any):
+ self.obj = obj
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class BinaryBinaryDumper(BytesBinaryDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
+
+class BinaryTextDumper(BytesDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
+
+def Date(year: int, month: int, day: int) -> dt.date:
+ return dt.date(year, month, day)
+
+
+def DateFromTicks(ticks: float) -> dt.date:
+ return TimestampFromTicks(ticks).date()
+
+
+def Time(hour: int, minute: int, second: int) -> dt.time:
+ return dt.time(hour, minute, second)
+
+
+def TimeFromTicks(ticks: float) -> dt.time:
+ return TimestampFromTicks(ticks).time()
+
+
+def Timestamp(
+ year: int, month: int, day: int, hour: int, minute: int, second: int
+) -> dt.datetime:
+ return dt.datetime(year, month, day, hour, minute, second)
+
+
+def TimestampFromTicks(ticks: float) -> dt.datetime:
+ secs = floor(ticks)
+ frac = ticks - secs
+ t = time.localtime(ticks)
+ tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
+ rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
+ return rv
+
+
+def register_dbapi20_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Binary, BinaryTextDumper)
+ adapters.register_dumper(Binary, BinaryBinaryDumper)
+
+ # Make them also the default dumpers when dumping by bytea oid
+ adapters.register_dumper(None, BinaryTextDumper)
+ adapters.register_dumper(None, BinaryBinaryDumper)
diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py
new file mode 100644
index 0000000..e176954
--- /dev/null
+++ b/psycopg/psycopg/errors.py
@@ -0,0 +1,1535 @@
+"""
+psycopg exceptions
+
+DBAPI-defined Exceptions are defined in the following hierarchy::
+
+ Exceptions
+ |__Warning
+ |__Error
+ |__InterfaceError
+ |__DatabaseError
+ |__DataError
+ |__OperationalError
+ |__IntegrityError
+ |__InternalError
+ |__ProgrammingError
+ |__NotSupportedError
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from typing_extensions import TypeAlias
+
+from .pq.abc import PGconn, PGresult
+from .pq._enums import DiagnosticField
+from ._compat import TypeGuard
+
+ErrorInfo: TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]]
+
+_sqlcodes: Dict[str, "Type[Error]"] = {}
+
+
+class Warning(Exception):
+ """
+ Exception raised for important warnings.
+
+ Defined for DBAPI compatibility, but never raised by ``psycopg``.
+ """
+
+ __module__ = "psycopg"
+
+
+class Error(Exception):
+ """
+ Base exception for all the errors psycopg will raise.
+
+ Exception that is the base class of all other error exceptions. You can
+ use this to catch all errors with one single `!except` statement.
+
+ This exception is guaranteed to be picklable.
+ """
+
+ __module__ = "psycopg"
+
+ sqlstate: Optional[str] = None
+
+ def __init__(
+ self,
+ *args: Sequence[Any],
+ info: ErrorInfo = None,
+ encoding: str = "utf-8",
+ pgconn: Optional[PGconn] = None
+ ):
+ super().__init__(*args)
+ self._info = info
+ self._encoding = encoding
+ self._pgconn = pgconn
+
+ # Handle sqlstate codes for which we don't have a class.
+ if not self.sqlstate and info:
+ self.sqlstate = self.diag.sqlstate
+
+ @property
+ def pgconn(self) -> Optional[PGconn]:
+ """The connection object, if the error was raised from a connection attempt.
+
+ :rtype: Optional[psycopg.pq.PGconn]
+ """
+ return self._pgconn if self._pgconn else None
+
+ @property
+ def pgresult(self) -> Optional[PGresult]:
+ """The result object, if the exception was raised after a failed query.
+
+ :rtype: Optional[psycopg.pq.PGresult]
+ """
+ return self._info if _is_pgresult(self._info) else None
+
+ @property
+ def diag(self) -> "Diagnostic":
+ """
+ A `Diagnostic` object to inspect details of the errors from the database.
+ """
+ return Diagnostic(self._info, encoding=self._encoding)
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ # To make the exception picklable
+ res[2]["_info"] = _info_to_dict(self._info)
+ res[2]["_pgconn"] = None
+
+ return res
+
+
+class InterfaceError(Error):
+ """
+ An error related to the database interface rather than the database itself.
+ """
+
+ __module__ = "psycopg"
+
+
+class DatabaseError(Error):
+ """
+ Exception raised for errors that are related to the database.
+ """
+
+ __module__ = "psycopg"
+
+ def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None):
+ if code:
+ _sqlcodes[code] = cls
+ cls.sqlstate = code
+ if name:
+ _sqlcodes[name] = cls
+
+
+class DataError(DatabaseError):
+ """
+ An error caused by problems with the processed data.
+
+ Examples may be division by zero, numeric value out of range, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class OperationalError(DatabaseError):
+ """
+ An error related to the database's operation.
+
+ These errors are not necessarily under the control of the programmer, e.g.
+ an unexpected disconnect occurs, the data source name is not found, a
+ transaction could not be processed, a memory allocation error occurred
+ during processing, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class IntegrityError(DatabaseError):
+ """
+ An error caused when the relational integrity of the database is affected.
+
+ An example may be a foreign key check failed.
+ """
+
+ __module__ = "psycopg"
+
+
+class InternalError(DatabaseError):
+ """
+ An error generated when the database encounters an internal error,
+
+ Examples could be the cursor is not valid anymore, the transaction is out
+ of sync, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class ProgrammingError(DatabaseError):
+ """
+ Exception raised for programming errors
+
+ Examples may be table not found or already exists, syntax error in the SQL
+ statement, wrong number of parameters specified, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class NotSupportedError(DatabaseError):
+ """
+ A method or database API was used which is not supported by the database.
+ """
+
+ __module__ = "psycopg"
+
+
+class ConnectionTimeout(OperationalError):
+ """
+ Exception raised on timeout of the `~psycopg.Connection.connect()` method.
+
+ The error is raised if the ``connect_timeout`` is specified and a
+ connection is not obtained in useful time.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class PipelineAborted(OperationalError):
+ """
+ Raised when a operation fails because the current pipeline is in aborted state.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class Diagnostic:
+ """Details from a database error report."""
+
+ def __init__(self, info: ErrorInfo, encoding: str = "utf-8"):
+ self._info = info
+ self._encoding = encoding
+
+ @property
+ def severity(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY)
+
+ @property
+ def severity_nonlocalized(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED)
+
+ @property
+ def sqlstate(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SQLSTATE)
+
+ @property
+ def message_primary(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_PRIMARY)
+
+ @property
+ def message_detail(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_DETAIL)
+
+ @property
+ def message_hint(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_HINT)
+
+ @property
+ def statement_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.STATEMENT_POSITION)
+
+ @property
+ def internal_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_POSITION)
+
+ @property
+ def internal_query(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_QUERY)
+
+ @property
+ def context(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONTEXT)
+
+ @property
+ def schema_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SCHEMA_NAME)
+
+ @property
+ def table_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.TABLE_NAME)
+
+ @property
+ def column_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.COLUMN_NAME)
+
+ @property
+ def datatype_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.DATATYPE_NAME)
+
+ @property
+ def constraint_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONSTRAINT_NAME)
+
+ @property
+ def source_file(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FILE)
+
+ @property
+ def source_line(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_LINE)
+
+ @property
+ def source_function(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FUNCTION)
+
+ def _error_message(self, field: DiagnosticField) -> Optional[str]:
+ if self._info:
+ if isinstance(self._info, dict):
+ val = self._info.get(field)
+ else:
+ val = self._info.error_field(field)
+
+ if val is not None:
+ return val.decode(self._encoding, "replace")
+
+ return None
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ res[2]["_info"] = _info_to_dict(self._info)
+
+ return res
+
+
+def _info_to_dict(info: ErrorInfo) -> ErrorInfo:
+ """
+ Convert a PGresult to a dictionary to make the info picklable.
+ """
+ # PGresult is a protocol, can't use isinstance
+ if _is_pgresult(info):
+ return {v: info.error_field(v) for v in DiagnosticField}
+ else:
+ return info
+
+
+def lookup(sqlstate: str) -> Type[Error]:
+ """Lookup an error code or `constant name`__ and return its exception class.
+
+ Raise `!KeyError` if the code is not found.
+
+ .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html
+ #ERRCODES-TABLE
+ """
+ return _sqlcodes[sqlstate.upper()]
+
+
+def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
+ from psycopg import pq
+
+ state = result.error_field(DiagnosticField.SQLSTATE) or b""
+ cls = _class_for_state(state.decode("ascii"))
+ return cls(
+ pq.error_message(result, encoding=encoding),
+ info=result,
+ encoding=encoding,
+ )
+
+
+def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]:
+ """Return True if an ErrorInfo is a PGresult instance."""
+ # PGresult is a protocol, can't use isinstance
+ return hasattr(info, "error_field")
+
+
+def _class_for_state(sqlstate: str) -> Type[Error]:
+ try:
+ return lookup(sqlstate)
+ except KeyError:
+ return get_base_exception(sqlstate)
+
+
+def get_base_exception(sqlstate: str) -> Type[Error]:
+ return (
+ _base_exc_map.get(sqlstate[:2])
+ or _base_exc_map.get(sqlstate[:1])
+ or DatabaseError
+ )
+
+
+_base_exc_map = {
+ "08": OperationalError, # Connection Exception
+ "0A": NotSupportedError, # Feature Not Supported
+ "20": ProgrammingError, # Case Not Foud
+ "21": ProgrammingError, # Cardinality Violation
+ "22": DataError, # Data Exception
+ "23": IntegrityError, # Integrity Constraint Violation
+ "24": InternalError, # Invalid Cursor State
+ "25": InternalError, # Invalid Transaction State
+ "26": ProgrammingError, # Invalid SQL Statement Name *
+ "27": OperationalError, # Triggered Data Change Violation
+ "28": OperationalError, # Invalid Authorization Specification
+ "2B": InternalError, # Dependent Privilege Descriptors Still Exist
+ "2D": InternalError, # Invalid Transaction Termination
+ "2F": OperationalError, # SQL Routine Exception *
+ "34": ProgrammingError, # Invalid Cursor Name *
+ "38": OperationalError, # External Routine Exception *
+ "39": OperationalError, # External Routine Invocation Exception *
+ "3B": OperationalError, # Savepoint Exception *
+ "3D": ProgrammingError, # Invalid Catalog Name
+ "3F": ProgrammingError, # Invalid Schema Name
+ "40": OperationalError, # Transaction Rollback
+ "42": ProgrammingError, # Syntax Error or Access Rule Violation
+ "44": ProgrammingError, # WITH CHECK OPTION Violation
+ "53": OperationalError, # Insufficient Resources
+ "54": OperationalError, # Program Limit Exceeded
+ "55": OperationalError, # Object Not In Prerequisite State
+ "57": OperationalError, # Operator Intervention
+ "58": OperationalError, # System Error (errors external to PostgreSQL itself)
+ "F": OperationalError, # Configuration File Error
+ "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED)
+ "P": ProgrammingError, # PL/pgSQL Error
+ "X": InternalError, # Internal Error
+}
+
+
+# Error classes generated by tools/update_errors.py
+
+# fmt: off
+# autogenerated: start
+
+
+# Class 02 - No Data (this is also a warning class per the SQL standard)
+
+class NoData(DatabaseError,
+ code='02000', name='NO_DATA'):
+ pass
+
+class NoAdditionalDynamicResultSetsReturned(DatabaseError,
+ code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'):
+ pass
+
+
+# Class 03 - SQL Statement Not Yet Complete
+
+class SqlStatementNotYetComplete(DatabaseError,
+ code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'):
+ pass
+
+
+# Class 08 - Connection Exception
+
+class ConnectionException(OperationalError,
+ code='08000', name='CONNECTION_EXCEPTION'):
+ pass
+
+class SqlclientUnableToEstablishSqlconnection(OperationalError,
+ code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'):
+ pass
+
+class ConnectionDoesNotExist(OperationalError,
+ code='08003', name='CONNECTION_DOES_NOT_EXIST'):
+ pass
+
+class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError,
+ code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'):
+ pass
+
+class ConnectionFailure(OperationalError,
+ code='08006', name='CONNECTION_FAILURE'):
+ pass
+
+class TransactionResolutionUnknown(OperationalError,
+ code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'):
+ pass
+
+class ProtocolViolation(OperationalError,
+ code='08P01', name='PROTOCOL_VIOLATION'):
+ pass
+
+
+# Class 09 - Triggered Action Exception
+
+class TriggeredActionException(DatabaseError,
+ code='09000', name='TRIGGERED_ACTION_EXCEPTION'):
+ pass
+
+
+# Class 0A - Feature Not Supported
+
+class FeatureNotSupported(NotSupportedError,
+ code='0A000', name='FEATURE_NOT_SUPPORTED'):
+ pass
+
+
+# Class 0B - Invalid Transaction Initiation
+
+class InvalidTransactionInitiation(DatabaseError,
+ code='0B000', name='INVALID_TRANSACTION_INITIATION'):
+ pass
+
+
+# Class 0F - Locator Exception
+
+class LocatorException(DatabaseError,
+ code='0F000', name='LOCATOR_EXCEPTION'):
+ pass
+
+class InvalidLocatorSpecification(DatabaseError,
+ code='0F001', name='INVALID_LOCATOR_SPECIFICATION'):
+ pass
+
+
+# Class 0L - Invalid Grantor
+
+class InvalidGrantor(DatabaseError,
+ code='0L000', name='INVALID_GRANTOR'):
+ pass
+
+class InvalidGrantOperation(DatabaseError,
+ code='0LP01', name='INVALID_GRANT_OPERATION'):
+ pass
+
+
+# Class 0P - Invalid Role Specification
+
+class InvalidRoleSpecification(DatabaseError,
+ code='0P000', name='INVALID_ROLE_SPECIFICATION'):
+ pass
+
+
+# Class 0Z - Diagnostics Exception
+
+class DiagnosticsException(DatabaseError,
+ code='0Z000', name='DIAGNOSTICS_EXCEPTION'):
+ pass
+
+class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError,
+ code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'):
+ pass
+
+
+# Class 20 - Case Not Found
+
+class CaseNotFound(ProgrammingError,
+ code='20000', name='CASE_NOT_FOUND'):
+ pass
+
+
+# Class 21 - Cardinality Violation
+
+class CardinalityViolation(ProgrammingError,
+ code='21000', name='CARDINALITY_VIOLATION'):
+ pass
+
+
+# Class 22 - Data Exception
+
+class DataException(DataError,
+ code='22000', name='DATA_EXCEPTION'):
+ pass
+
+class StringDataRightTruncation(DataError,
+ code='22001', name='STRING_DATA_RIGHT_TRUNCATION'):
+ pass
+
+class NullValueNoIndicatorParameter(DataError,
+ code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'):
+ pass
+
+class NumericValueOutOfRange(DataError,
+ code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'):
+ pass
+
+class NullValueNotAllowed(DataError,
+ code='22004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class ErrorInAssignment(DataError,
+ code='22005', name='ERROR_IN_ASSIGNMENT'):
+ pass
+
+class InvalidDatetimeFormat(DataError,
+ code='22007', name='INVALID_DATETIME_FORMAT'):
+ pass
+
+class DatetimeFieldOverflow(DataError,
+ code='22008', name='DATETIME_FIELD_OVERFLOW'):
+ pass
+
+class InvalidTimeZoneDisplacementValue(DataError,
+ code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'):
+ pass
+
+class EscapeCharacterConflict(DataError,
+ code='2200B', name='ESCAPE_CHARACTER_CONFLICT'):
+ pass
+
+class InvalidUseOfEscapeCharacter(DataError,
+ code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidEscapeOctet(DataError,
+ code='2200D', name='INVALID_ESCAPE_OCTET'):
+ pass
+
+class ZeroLengthCharacterString(DataError,
+ code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'):
+ pass
+
+class MostSpecificTypeMismatch(DataError,
+ code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'):
+ pass
+
+class SequenceGeneratorLimitExceeded(DataError,
+ code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'):
+ pass
+
+class NotAnXmlDocument(DataError,
+ code='2200L', name='NOT_AN_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlDocument(DataError,
+ code='2200M', name='INVALID_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlContent(DataError,
+ code='2200N', name='INVALID_XML_CONTENT'):
+ pass
+
+class InvalidXmlComment(DataError,
+ code='2200S', name='INVALID_XML_COMMENT'):
+ pass
+
+class InvalidXmlProcessingInstruction(DataError,
+ code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'):
+ pass
+
+class InvalidIndicatorParameterValue(DataError,
+ code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'):
+ pass
+
+class SubstringError(DataError,
+ code='22011', name='SUBSTRING_ERROR'):
+ pass
+
+class DivisionByZero(DataError,
+ code='22012', name='DIVISION_BY_ZERO'):
+ pass
+
+class InvalidPrecedingOrFollowingSize(DataError,
+ code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'):
+ pass
+
+class InvalidArgumentForNtileFunction(DataError,
+ code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'):
+ pass
+
+class IntervalFieldOverflow(DataError,
+ code='22015', name='INTERVAL_FIELD_OVERFLOW'):
+ pass
+
+class InvalidArgumentForNthValueFunction(DataError,
+ code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'):
+ pass
+
+class InvalidCharacterValueForCast(DataError,
+ code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'):
+ pass
+
+class InvalidEscapeCharacter(DataError,
+ code='22019', name='INVALID_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidRegularExpression(DataError,
+ code='2201B', name='INVALID_REGULAR_EXPRESSION'):
+ pass
+
+class InvalidArgumentForLogarithm(DataError,
+ code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'):
+ pass
+
+class InvalidArgumentForPowerFunction(DataError,
+ code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'):
+ pass
+
+class InvalidArgumentForWidthBucketFunction(DataError,
+ code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'):
+ pass
+
+class InvalidRowCountInLimitClause(DataError,
+ code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'):
+ pass
+
+class InvalidRowCountInResultOffsetClause(DataError,
+ code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'):
+ pass
+
+class CharacterNotInRepertoire(DataError,
+ code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'):
+ pass
+
+class IndicatorOverflow(DataError,
+ code='22022', name='INDICATOR_OVERFLOW'):
+ pass
+
+class InvalidParameterValue(DataError,
+ code='22023', name='INVALID_PARAMETER_VALUE'):
+ pass
+
+class UnterminatedCString(DataError,
+ code='22024', name='UNTERMINATED_C_STRING'):
+ pass
+
+class InvalidEscapeSequence(DataError,
+ code='22025', name='INVALID_ESCAPE_SEQUENCE'):
+ pass
+
+class StringDataLengthMismatch(DataError,
+ code='22026', name='STRING_DATA_LENGTH_MISMATCH'):
+ pass
+
+class TrimError(DataError,
+ code='22027', name='TRIM_ERROR'):
+ pass
+
+class ArraySubscriptError(DataError,
+ code='2202E', name='ARRAY_SUBSCRIPT_ERROR'):
+ pass
+
+class InvalidTablesampleRepeat(DataError,
+ code='2202G', name='INVALID_TABLESAMPLE_REPEAT'):
+ pass
+
+class InvalidTablesampleArgument(DataError,
+ code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'):
+ pass
+
+class DuplicateJsonObjectKeyValue(DataError,
+ code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'):
+ pass
+
+class InvalidArgumentForSqlJsonDatetimeFunction(DataError,
+ code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'):
+ pass
+
+class InvalidJsonText(DataError,
+ code='22032', name='INVALID_JSON_TEXT'):
+ pass
+
+class InvalidSqlJsonSubscript(DataError,
+ code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'):
+ pass
+
+class MoreThanOneSqlJsonItem(DataError,
+ code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'):
+ pass
+
+class NoSqlJsonItem(DataError,
+ code='22035', name='NO_SQL_JSON_ITEM'):
+ pass
+
+class NonNumericSqlJsonItem(DataError,
+ code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'):
+ pass
+
+class NonUniqueKeysInAJsonObject(DataError,
+ code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'):
+ pass
+
+class SingletonSqlJsonItemRequired(DataError,
+ code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'):
+ pass
+
+class SqlJsonArrayNotFound(DataError,
+ code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'):
+ pass
+
+class SqlJsonMemberNotFound(DataError,
+ code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonNumberNotFound(DataError,
+ code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonObjectNotFound(DataError,
+ code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'):
+ pass
+
+class TooManyJsonArrayElements(DataError,
+ code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'):
+ pass
+
+class TooManyJsonObjectMembers(DataError,
+ code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'):
+ pass
+
+class SqlJsonScalarRequired(DataError,
+ code='2203F', name='SQL_JSON_SCALAR_REQUIRED'):
+ pass
+
+class SqlJsonItemCannotBeCastToTargetType(DataError,
+ code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'):
+ pass
+
+class FloatingPointException(DataError,
+ code='22P01', name='FLOATING_POINT_EXCEPTION'):
+ pass
+
+class InvalidTextRepresentation(DataError,
+ code='22P02', name='INVALID_TEXT_REPRESENTATION'):
+ pass
+
+class InvalidBinaryRepresentation(DataError,
+ code='22P03', name='INVALID_BINARY_REPRESENTATION'):
+ pass
+
+class BadCopyFileFormat(DataError,
+ code='22P04', name='BAD_COPY_FILE_FORMAT'):
+ pass
+
+class UntranslatableCharacter(DataError,
+ code='22P05', name='UNTRANSLATABLE_CHARACTER'):
+ pass
+
+class NonstandardUseOfEscapeCharacter(DataError,
+ code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+
+# Class 23 - Integrity Constraint Violation
+
+class IntegrityConstraintViolation(IntegrityError,
+ code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class RestrictViolation(IntegrityError,
+ code='23001', name='RESTRICT_VIOLATION'):
+ pass
+
+class NotNullViolation(IntegrityError,
+ code='23502', name='NOT_NULL_VIOLATION'):
+ pass
+
+class ForeignKeyViolation(IntegrityError,
+ code='23503', name='FOREIGN_KEY_VIOLATION'):
+ pass
+
+class UniqueViolation(IntegrityError,
+ code='23505', name='UNIQUE_VIOLATION'):
+ pass
+
+class CheckViolation(IntegrityError,
+ code='23514', name='CHECK_VIOLATION'):
+ pass
+
+class ExclusionViolation(IntegrityError,
+ code='23P01', name='EXCLUSION_VIOLATION'):
+ pass
+
+
+# Class 24 - Invalid Cursor State
+
+class InvalidCursorState(InternalError,
+ code='24000', name='INVALID_CURSOR_STATE'):
+ pass
+
+
+# Class 25 - Invalid Transaction State
+
+class InvalidTransactionState(InternalError,
+ code='25000', name='INVALID_TRANSACTION_STATE'):
+ pass
+
+class ActiveSqlTransaction(InternalError,
+ code='25001', name='ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class BranchTransactionAlreadyActive(InternalError,
+ code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'):
+ pass
+
+class InappropriateAccessModeForBranchTransaction(InternalError,
+ code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class InappropriateIsolationLevelForBranchTransaction(InternalError,
+ code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class NoActiveSqlTransactionForBranchTransaction(InternalError,
+ code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class ReadOnlySqlTransaction(InternalError,
+ code='25006', name='READ_ONLY_SQL_TRANSACTION'):
+ pass
+
+class SchemaAndDataStatementMixingNotSupported(InternalError,
+ code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'):
+ pass
+
+class HeldCursorRequiresSameIsolationLevel(InternalError,
+ code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'):
+ pass
+
+class NoActiveSqlTransaction(InternalError,
+ code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class InFailedSqlTransaction(InternalError,
+ code='25P02', name='IN_FAILED_SQL_TRANSACTION'):
+ pass
+
+class IdleInTransactionSessionTimeout(InternalError,
+ code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 26 - Invalid SQL Statement Name
+
+class InvalidSqlStatementName(ProgrammingError,
+ code='26000', name='INVALID_SQL_STATEMENT_NAME'):
+ pass
+
+
+# Class 27 - Triggered Data Change Violation
+
+class TriggeredDataChangeViolation(OperationalError,
+ code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'):
+ pass
+
+
+# Class 28 - Invalid Authorization Specification
+
+class InvalidAuthorizationSpecification(OperationalError,
+ code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'):
+ pass
+
+class InvalidPassword(OperationalError,
+ code='28P01', name='INVALID_PASSWORD'):
+ pass
+
+
+# Class 2B - Dependent Privilege Descriptors Still Exist
+
+class DependentPrivilegeDescriptorsStillExist(InternalError,
+ code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'):
+ pass
+
+class DependentObjectsStillExist(InternalError,
+ code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'):
+ pass
+
+
+# Class 2D - Invalid Transaction Termination
+
+class InvalidTransactionTermination(InternalError,
+ code='2D000', name='INVALID_TRANSACTION_TERMINATION'):
+ pass
+
+
+# Class 2F - SQL Routine Exception
+
+class SqlRoutineException(OperationalError,
+ code='2F000', name='SQL_ROUTINE_EXCEPTION'):
+ pass
+
+class ModifyingSqlDataNotPermitted(OperationalError,
+ code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttempted(OperationalError,
+ code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermitted(OperationalError,
+ code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class FunctionExecutedNoReturnStatement(OperationalError,
+ code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'):
+ pass
+
+
+# Class 34 - Invalid Cursor Name
+
+class InvalidCursorName(ProgrammingError,
+ code='34000', name='INVALID_CURSOR_NAME'):
+ pass
+
+
+# Class 38 - External Routine Exception
+
+class ExternalRoutineException(OperationalError,
+ code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'):
+ pass
+
+class ContainingSqlNotPermitted(OperationalError,
+ code='38001', name='CONTAINING_SQL_NOT_PERMITTED'):
+ pass
+
+class ModifyingSqlDataNotPermittedExt(OperationalError,
+ code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttemptedExt(OperationalError,
+ code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermittedExt(OperationalError,
+ code='38004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+
+# Class 39 - External Routine Invocation Exception
+
+class ExternalRoutineInvocationException(OperationalError,
+ code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'):
+ pass
+
+class InvalidSqlstateReturned(OperationalError,
+ code='39001', name='INVALID_SQLSTATE_RETURNED'):
+ pass
+
+class NullValueNotAllowedExt(OperationalError,
+ code='39004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class TriggerProtocolViolated(OperationalError,
+ code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+class SrfProtocolViolated(OperationalError,
+ code='39P02', name='SRF_PROTOCOL_VIOLATED'):
+ pass
+
+class EventTriggerProtocolViolated(OperationalError,
+ code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+
+# Class 3B - Savepoint Exception
+
+class SavepointException(OperationalError,
+ code='3B000', name='SAVEPOINT_EXCEPTION'):
+ pass
+
+class InvalidSavepointSpecification(OperationalError,
+ code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'):
+ pass
+
+
+# Class 3D - Invalid Catalog Name
+
+class InvalidCatalogName(ProgrammingError,
+ code='3D000', name='INVALID_CATALOG_NAME'):
+ pass
+
+
+# Class 3F - Invalid Schema Name
+
+class InvalidSchemaName(ProgrammingError,
+ code='3F000', name='INVALID_SCHEMA_NAME'):
+ pass
+
+
+# Class 40 - Transaction Rollback
+
+class TransactionRollback(OperationalError,
+ code='40000', name='TRANSACTION_ROLLBACK'):
+ pass
+
+class SerializationFailure(OperationalError,
+ code='40001', name='SERIALIZATION_FAILURE'):
+ pass
+
+class TransactionIntegrityConstraintViolation(OperationalError,
+ code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class StatementCompletionUnknown(OperationalError,
+ code='40003', name='STATEMENT_COMPLETION_UNKNOWN'):
+ pass
+
+class DeadlockDetected(OperationalError,
+ code='40P01', name='DEADLOCK_DETECTED'):
+ pass
+
+
+# Class 42 - Syntax Error or Access Rule Violation
+
+class SyntaxErrorOrAccessRuleViolation(ProgrammingError,
+ code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'):
+ pass
+
+class InsufficientPrivilege(ProgrammingError,
+ code='42501', name='INSUFFICIENT_PRIVILEGE'):
+ pass
+
+class SyntaxError(ProgrammingError,
+ code='42601', name='SYNTAX_ERROR'):
+ pass
+
+class InvalidName(ProgrammingError,
+ code='42602', name='INVALID_NAME'):
+ pass
+
+class InvalidColumnDefinition(ProgrammingError,
+ code='42611', name='INVALID_COLUMN_DEFINITION'):
+ pass
+
+class NameTooLong(ProgrammingError,
+ code='42622', name='NAME_TOO_LONG'):
+ pass
+
+class DuplicateColumn(ProgrammingError,
+ code='42701', name='DUPLICATE_COLUMN'):
+ pass
+
+class AmbiguousColumn(ProgrammingError,
+ code='42702', name='AMBIGUOUS_COLUMN'):
+ pass
+
+class UndefinedColumn(ProgrammingError,
+ code='42703', name='UNDEFINED_COLUMN'):
+ pass
+
+class UndefinedObject(ProgrammingError,
+ code='42704', name='UNDEFINED_OBJECT'):
+ pass
+
+class DuplicateObject(ProgrammingError,
+ code='42710', name='DUPLICATE_OBJECT'):
+ pass
+
+class DuplicateAlias(ProgrammingError,
+ code='42712', name='DUPLICATE_ALIAS'):
+ pass
+
+class DuplicateFunction(ProgrammingError,
+ code='42723', name='DUPLICATE_FUNCTION'):
+ pass
+
+class AmbiguousFunction(ProgrammingError,
+ code='42725', name='AMBIGUOUS_FUNCTION'):
+ pass
+
+class GroupingError(ProgrammingError,
+ code='42803', name='GROUPING_ERROR'):
+ pass
+
+class DatatypeMismatch(ProgrammingError,
+ code='42804', name='DATATYPE_MISMATCH'):
+ pass
+
+class WrongObjectType(ProgrammingError,
+ code='42809', name='WRONG_OBJECT_TYPE'):
+ pass
+
+class InvalidForeignKey(ProgrammingError,
+ code='42830', name='INVALID_FOREIGN_KEY'):
+ pass
+
+class CannotCoerce(ProgrammingError,
+ code='42846', name='CANNOT_COERCE'):
+ pass
+
+class UndefinedFunction(ProgrammingError,
+ code='42883', name='UNDEFINED_FUNCTION'):
+ pass
+
+class GeneratedAlways(ProgrammingError,
+ code='428C9', name='GENERATED_ALWAYS'):
+ pass
+
+class ReservedName(ProgrammingError,
+ code='42939', name='RESERVED_NAME'):
+ pass
+
+class UndefinedTable(ProgrammingError,
+ code='42P01', name='UNDEFINED_TABLE'):
+ pass
+
+class UndefinedParameter(ProgrammingError,
+ code='42P02', name='UNDEFINED_PARAMETER'):
+ pass
+
+class DuplicateCursor(ProgrammingError,
+ code='42P03', name='DUPLICATE_CURSOR'):
+ pass
+
+class DuplicateDatabase(ProgrammingError,
+ code='42P04', name='DUPLICATE_DATABASE'):
+ pass
+
+class DuplicatePreparedStatement(ProgrammingError,
+ code='42P05', name='DUPLICATE_PREPARED_STATEMENT'):
+ pass
+
+class DuplicateSchema(ProgrammingError,
+ code='42P06', name='DUPLICATE_SCHEMA'):
+ pass
+
+class DuplicateTable(ProgrammingError,
+ code='42P07', name='DUPLICATE_TABLE'):
+ pass
+
+class AmbiguousParameter(ProgrammingError,
+ code='42P08', name='AMBIGUOUS_PARAMETER'):
+ pass
+
+class AmbiguousAlias(ProgrammingError,
+ code='42P09', name='AMBIGUOUS_ALIAS'):
+ pass
+
+class InvalidColumnReference(ProgrammingError,
+ code='42P10', name='INVALID_COLUMN_REFERENCE'):
+ pass
+
+class InvalidCursorDefinition(ProgrammingError,
+ code='42P11', name='INVALID_CURSOR_DEFINITION'):
+ pass
+
+class InvalidDatabaseDefinition(ProgrammingError,
+ code='42P12', name='INVALID_DATABASE_DEFINITION'):
+ pass
+
+class InvalidFunctionDefinition(ProgrammingError,
+ code='42P13', name='INVALID_FUNCTION_DEFINITION'):
+ pass
+
+class InvalidPreparedStatementDefinition(ProgrammingError,
+ code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'):
+ pass
+
+class InvalidSchemaDefinition(ProgrammingError,
+ code='42P15', name='INVALID_SCHEMA_DEFINITION'):
+ pass
+
+class InvalidTableDefinition(ProgrammingError,
+ code='42P16', name='INVALID_TABLE_DEFINITION'):
+ pass
+
+class InvalidObjectDefinition(ProgrammingError,
+ code='42P17', name='INVALID_OBJECT_DEFINITION'):
+ pass
+
+class IndeterminateDatatype(ProgrammingError,
+ code='42P18', name='INDETERMINATE_DATATYPE'):
+ pass
+
+class InvalidRecursion(ProgrammingError,
+ code='42P19', name='INVALID_RECURSION'):
+ pass
+
+class WindowingError(ProgrammingError,
+ code='42P20', name='WINDOWING_ERROR'):
+ pass
+
+class CollationMismatch(ProgrammingError,
+ code='42P21', name='COLLATION_MISMATCH'):
+ pass
+
+class IndeterminateCollation(ProgrammingError,
+ code='42P22', name='INDETERMINATE_COLLATION'):
+ pass
+
+
+# Class 44 - WITH CHECK OPTION Violation
+
+class WithCheckOptionViolation(ProgrammingError,
+ code='44000', name='WITH_CHECK_OPTION_VIOLATION'):
+ pass
+
+
+# Class 53 - Insufficient Resources
+
+class InsufficientResources(OperationalError,
+ code='53000', name='INSUFFICIENT_RESOURCES'):
+ pass
+
+class DiskFull(OperationalError,
+ code='53100', name='DISK_FULL'):
+ pass
+
+class OutOfMemory(OperationalError,
+ code='53200', name='OUT_OF_MEMORY'):
+ pass
+
+class TooManyConnections(OperationalError,
+ code='53300', name='TOO_MANY_CONNECTIONS'):
+ pass
+
+class ConfigurationLimitExceeded(OperationalError,
+ code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'):
+ pass
+
+
+# Class 54 - Program Limit Exceeded
+
+class ProgramLimitExceeded(OperationalError,
+ code='54000', name='PROGRAM_LIMIT_EXCEEDED'):
+ pass
+
+class StatementTooComplex(OperationalError,
+ code='54001', name='STATEMENT_TOO_COMPLEX'):
+ pass
+
+class TooManyColumns(OperationalError,
+ code='54011', name='TOO_MANY_COLUMNS'):
+ pass
+
+class TooManyArguments(OperationalError,
+ code='54023', name='TOO_MANY_ARGUMENTS'):
+ pass
+
+
+# Class 55 - Object Not In Prerequisite State
+
+class ObjectNotInPrerequisiteState(OperationalError,
+ code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'):
+ pass
+
+class ObjectInUse(OperationalError,
+ code='55006', name='OBJECT_IN_USE'):
+ pass
+
+class CantChangeRuntimeParam(OperationalError,
+ code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'):
+ pass
+
+class LockNotAvailable(OperationalError,
+ code='55P03', name='LOCK_NOT_AVAILABLE'):
+ pass
+
+class UnsafeNewEnumValueUsage(OperationalError,
+ code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'):
+ pass
+
+
+# Class 57 - Operator Intervention
+
+class OperatorIntervention(OperationalError,
+ code='57000', name='OPERATOR_INTERVENTION'):
+ pass
+
+class QueryCanceled(OperationalError,
+ code='57014', name='QUERY_CANCELED'):
+ pass
+
+class AdminShutdown(OperationalError,
+ code='57P01', name='ADMIN_SHUTDOWN'):
+ pass
+
+class CrashShutdown(OperationalError,
+ code='57P02', name='CRASH_SHUTDOWN'):
+ pass
+
+class CannotConnectNow(OperationalError,
+ code='57P03', name='CANNOT_CONNECT_NOW'):
+ pass
+
+class DatabaseDropped(OperationalError,
+ code='57P04', name='DATABASE_DROPPED'):
+ pass
+
+class IdleSessionTimeout(OperationalError,
+ code='57P05', name='IDLE_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 58 - System Error (errors external to PostgreSQL itself)
+
+class SystemError(OperationalError,
+ code='58000', name='SYSTEM_ERROR'):
+ pass
+
+class IoError(OperationalError,
+ code='58030', name='IO_ERROR'):
+ pass
+
+class UndefinedFile(OperationalError,
+ code='58P01', name='UNDEFINED_FILE'):
+ pass
+
+class DuplicateFile(OperationalError,
+ code='58P02', name='DUPLICATE_FILE'):
+ pass
+
+
+# Class 72 - Snapshot Failure
+
+class SnapshotTooOld(DatabaseError,
+ code='72000', name='SNAPSHOT_TOO_OLD'):
+ pass
+
+
+# Class F0 - Configuration File Error
+
+class ConfigFileError(OperationalError,
+ code='F0000', name='CONFIG_FILE_ERROR'):
+ pass
+
+class LockFileExists(OperationalError,
+ code='F0001', name='LOCK_FILE_EXISTS'):
+ pass
+
+
+# Class HV - Foreign Data Wrapper Error (SQL/MED)
+
+class FdwError(OperationalError,
+ code='HV000', name='FDW_ERROR'):
+ pass
+
+class FdwOutOfMemory(OperationalError,
+ code='HV001', name='FDW_OUT_OF_MEMORY'):
+ pass
+
+class FdwDynamicParameterValueNeeded(OperationalError,
+ code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'):
+ pass
+
+class FdwInvalidDataType(OperationalError,
+ code='HV004', name='FDW_INVALID_DATA_TYPE'):
+ pass
+
+class FdwColumnNameNotFound(OperationalError,
+ code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'):
+ pass
+
+class FdwInvalidDataTypeDescriptors(OperationalError,
+ code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'):
+ pass
+
+class FdwInvalidColumnName(OperationalError,
+ code='HV007', name='FDW_INVALID_COLUMN_NAME'):
+ pass
+
+class FdwInvalidColumnNumber(OperationalError,
+ code='HV008', name='FDW_INVALID_COLUMN_NUMBER'):
+ pass
+
+class FdwInvalidUseOfNullPointer(OperationalError,
+ code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'):
+ pass
+
+class FdwInvalidStringFormat(OperationalError,
+ code='HV00A', name='FDW_INVALID_STRING_FORMAT'):
+ pass
+
+class FdwInvalidHandle(OperationalError,
+ code='HV00B', name='FDW_INVALID_HANDLE'):
+ pass
+
+class FdwInvalidOptionIndex(OperationalError,
+ code='HV00C', name='FDW_INVALID_OPTION_INDEX'):
+ pass
+
+class FdwInvalidOptionName(OperationalError,
+ code='HV00D', name='FDW_INVALID_OPTION_NAME'):
+ pass
+
+class FdwOptionNameNotFound(OperationalError,
+ code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'):
+ pass
+
+class FdwReplyHandle(OperationalError,
+ code='HV00K', name='FDW_REPLY_HANDLE'):
+ pass
+
+class FdwUnableToCreateExecution(OperationalError,
+ code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'):
+ pass
+
+class FdwUnableToCreateReply(OperationalError,
+ code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'):
+ pass
+
+class FdwUnableToEstablishConnection(OperationalError,
+ code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'):
+ pass
+
+class FdwNoSchemas(OperationalError,
+ code='HV00P', name='FDW_NO_SCHEMAS'):
+ pass
+
+class FdwSchemaNotFound(OperationalError,
+ code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'):
+ pass
+
+class FdwTableNotFound(OperationalError,
+ code='HV00R', name='FDW_TABLE_NOT_FOUND'):
+ pass
+
+class FdwFunctionSequenceError(OperationalError,
+ code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'):
+ pass
+
+class FdwTooManyHandles(OperationalError,
+ code='HV014', name='FDW_TOO_MANY_HANDLES'):
+ pass
+
+class FdwInconsistentDescriptorInformation(OperationalError,
+ code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'):
+ pass
+
+class FdwInvalidAttributeValue(OperationalError,
+ code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'):
+ pass
+
+class FdwInvalidStringLengthOrBufferLength(OperationalError,
+ code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'):
+ pass
+
+class FdwInvalidDescriptorFieldIdentifier(OperationalError,
+ code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'):
+ pass
+
+
+# Class P0 - PL/pgSQL Error
+
+class PlpgsqlError(ProgrammingError,
+ code='P0000', name='PLPGSQL_ERROR'):
+ pass
+
+class RaiseException(ProgrammingError,
+ code='P0001', name='RAISE_EXCEPTION'):
+ pass
+
+class NoDataFound(ProgrammingError,
+ code='P0002', name='NO_DATA_FOUND'):
+ pass
+
+class TooManyRows(ProgrammingError,
+ code='P0003', name='TOO_MANY_ROWS'):
+ pass
+
+class AssertFailure(ProgrammingError,
+ code='P0004', name='ASSERT_FAILURE'):
+ pass
+
+
+# Class XX - Internal Error
+
+class InternalError_(InternalError,
+ code='XX000', name='INTERNAL_ERROR'):
+ pass
+
+class DataCorrupted(InternalError,
+ code='XX001', name='DATA_CORRUPTED'):
+ pass
+
+class IndexCorrupted(InternalError,
+ code='XX002', name='INDEX_CORRUPTED'):
+ pass
+
+
+# autogenerated: end
+# fmt: on
diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py
new file mode 100644
index 0000000..584fe47
--- /dev/null
+++ b/psycopg/psycopg/generators.py
@@ -0,0 +1,320 @@
+"""
+Generators implementing communication protocols with the libpq
+
+Certain operations (connection, querying) are an interleave of libpq calls and
+waiting for the socket to be ready. This module contains the code to execute
+the operations, yielding a polling state whenever there is to wait. The
+functions in the `waiting` module are the ones who wait more or less
+cooperatively for the socket to be ready and make these generators continue.
+
+All these generators yield pairs (fileno, `Wait`) whenever an operation would
+block. The generator can be restarted sending the appropriate `Ready` state
+when the file descriptor is ready.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import List, Optional, Union
+
+from . import pq
+from . import errors as e
+from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
+from .pq.abc import PGconn, PGresult
+from .waiting import Wait, Ready
+from ._compat import Deque
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding, conninfo_encoding
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+POLL_OK = pq.PollingStatus.OK
+POLL_READING = pq.PollingStatus.READING
+POLL_WRITING = pq.PollingStatus.WRITING
+POLL_FAILED = pq.PollingStatus.FAILED
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+logger = logging.getLogger(__name__)
+
+
+def _connect(conninfo: str) -> PQGenConn[PGconn]:
+ """
+ Generator to create a database connection without blocking.
+
+ """
+ conn = pq.PGconn.connect_start(conninfo.encode())
+ while True:
+ if conn.status == BAD:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection is bad: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+
+ status = conn.connect_poll()
+ if status == POLL_OK:
+ break
+ elif status == POLL_READING:
+ yield conn.socket, WAIT_R
+ elif status == POLL_WRITING:
+ yield conn.socket, WAIT_W
+ elif status == POLL_FAILED:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection failed: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+ else:
+ raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn)
+
+ conn.nonblocking = 1
+ return conn
+
+
+def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator sending a query and returning results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ yield from _send(pgconn)
+ rv = yield from _fetch_many(pgconn)
+ return rv
+
+
+def _send(pgconn: PGconn) -> PQGen[None]:
+ """
+ Generator to send a query to the server without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ After this generator has finished you may want to cycle using `fetch()`
+ to retrieve the results available.
+ """
+ while True:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ ready = yield WAIT_RW
+ if ready & READY_R:
+ # This call may read notifies: they will be saved in the
+ # PGconn buffer and passed to Python later, in `fetch()`.
+ pgconn.consume_input()
+
+
+def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator retrieving results from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ results: List[PGresult] = []
+ while True:
+ res = yield from _fetch(pgconn)
+ if not res:
+ break
+
+ results.append(res)
+ status = res.status
+ if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ # After entering copy mode the libpq will create a phony result
+ # for every request so let's break the endless loop.
+ break
+
+ if status == PIPELINE_SYNC:
+ # PIPELINE_SYNC is not followed by a NULL, but we return it alone
+ # similarly to other result sets.
+ assert len(results) == 1, results
+ break
+
+ return results
+
+
+def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
+ """
+ Generator retrieving a single result from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return a result from the database (whether success or error).
+ """
+ if pgconn.is_busy():
+ yield WAIT_R
+ while True:
+ pgconn.consume_input()
+ if not pgconn.is_busy():
+ break
+ yield WAIT_R
+
+ _consume_notifies(pgconn)
+
+ return pgconn.get_result()
+
+
+def _pipeline_communicate(
+ pgconn: PGconn, commands: Deque[PipelineCommand]
+) -> PQGen[List[List[PGresult]]]:
+ """Generator to send queries from a connection in pipeline mode while also
+ receiving results.
+
+ Return a list results, including single PIPELINE_SYNC elements.
+ """
+ results = []
+
+ while True:
+ ready = yield WAIT_RW
+
+ if ready & READY_R:
+ pgconn.consume_input()
+ _consume_notifies(pgconn)
+
+ res: List[PGresult] = []
+ while not pgconn.is_busy():
+ r = pgconn.get_result()
+ if r is None:
+ if not res:
+ break
+ results.append(res)
+ res = []
+ elif r.status == PIPELINE_SYNC:
+ assert not res
+ results.append([r])
+ else:
+ res.append(r)
+
+ if ready & READY_W:
+ pgconn.flush()
+ if not commands:
+ break
+ commands.popleft()()
+
+ return results
+
+
+def _consume_notifies(pgconn: PGconn) -> None:
+ # Consume notifies
+ while True:
+ n = pgconn.notifies()
+ if not n:
+ break
+ if pgconn.notify_handler:
+ pgconn.notify_handler(n)
+
+
+def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
+ yield WAIT_R
+ pgconn.consume_input()
+
+ ns = []
+ while True:
+ n = pgconn.notifies()
+ if n:
+ ns.append(n)
+ else:
+ break
+
+ return ns
+
+
+def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
+ while True:
+ nbytes, data = pgconn.get_copy_data(1)
+ if nbytes != 0:
+ break
+
+ # would block
+ yield WAIT_R
+ pgconn.consume_input()
+
+ if nbytes > 0:
+ # some data
+ return data
+
+ # Retrieve the final result of copy
+ results = yield from _fetch_many(pgconn)
+ if len(results) > 1:
+ # TODO: too brutal? Copy worked.
+ raise e.ProgrammingError("you cannot mix COPY with other operations")
+ result = results[0]
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
+ # Retry enqueuing data until successful.
+ #
+ # WARNING! This can cause an infinite loop if the buffer is too large. (see
+ # ticket #255). We avoid it in the Copy object by splitting a large buffer
+ # into smaller ones. We prefer to do it there instead of here in order to
+ # do it upstream the queue decoupling the writer task from the producer one.
+ while pgconn.put_copy_data(buffer) == 0:
+ yield WAIT_W
+
+
+def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
+ # Retry enqueuing end copy message until successful
+ while pgconn.put_copy_end(error) == 0:
+ yield WAIT_W
+
+ # Repeat until it the message is flushed to the server
+ while True:
+ yield WAIT_W
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ # Retrieve the final result of copy
+ (result,) = yield from _fetch_many(pgconn)
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+# Override functions with fast versions if available
+if _psycopg:
+ connect = _psycopg.connect
+ execute = _psycopg.execute
+ send = _psycopg.send
+ fetch_many = _psycopg.fetch_many
+ fetch = _psycopg.fetch
+ pipeline_communicate = _psycopg.pipeline_communicate
+
+else:
+ connect = _connect
+ execute = _execute
+ send = _send
+ fetch_many = _fetch_many
+ fetch = _fetch
+ pipeline_communicate = _pipeline_communicate
diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py
new file mode 100644
index 0000000..792a9c8
--- /dev/null
+++ b/psycopg/psycopg/postgres.py
@@ -0,0 +1,125 @@
+"""
+Types configuration specific to PostgreSQL.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
+from .abc import AdaptContext
+from ._adapters_map import AdaptersMap
+
+# Global objects with PostgreSQL builtins and globally registered user types.
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+# Use tools/update_oids.py to update this data.
+for t in [
+ TypeInfo('"char"', 18, 1002),
+ # autogenerated: start
+ # Generated from PostgreSQL 15.0
+ TypeInfo("aclitem", 1033, 1034),
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("box", 603, 1020, delimiter=";"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("cid", 29, 1012),
+ TypeInfo("cidr", 650, 651),
+ TypeInfo("circle", 718, 719),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("gtsvector", 3642, 3644),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007, regtype="integer"),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("json", 114, 199),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("jsonpath", 4072, 4073),
+ TypeInfo("line", 628, 629),
+ TypeInfo("lseg", 601, 1018),
+ TypeInfo("macaddr", 829, 1040),
+ TypeInfo("macaddr8", 774, 775),
+ TypeInfo("money", 790, 791),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("path", 602, 1019),
+ TypeInfo("pg_lsn", 3220, 3221),
+ TypeInfo("point", 600, 1017),
+ TypeInfo("polygon", 604, 1027),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("refcursor", 1790, 2201),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regcollation", 4191, 4192),
+ TypeInfo("regconfig", 3734, 3735),
+ TypeInfo("regdictionary", 3769, 3770),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regoper", 2203, 2208),
+ TypeInfo("regoperator", 2204, 2209),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("tid", 27, 1010),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("tsquery", 3615, 3645),
+ TypeInfo("tsvector", 3614, 3643),
+ TypeInfo("txid_snapshot", 2970, 2949),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ TypeInfo("xid", 28, 1011),
+ TypeInfo("xid8", 5069, 271),
+ TypeInfo("xml", 142, 143),
+ RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
+ RangeInfo("int4range", 3904, 3905, subtype_oid=23),
+ RangeInfo("int8range", 3926, 3927, subtype_oid=20),
+ RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
+ RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
+ RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
+ MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
+ MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
+ MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
+ MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
+ MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
+ MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
+ # autogenerated: end
+]:
+ types.add(t)
+
+
+# A few oids used a bit everywhere
+INVALID_OID = 0
+TEXT_OID = types["text"].oid
+TEXT_ARRAY_OID = types["text"].array_oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+
+ from .types import array, bool, composite, datetime, enum, json, multirange
+ from .types import net, none, numeric, range, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ enum.register_default_adapters(context)
+ json.register_default_adapters(context)
+ multirange.register_default_adapters(context)
+ net.register_default_adapters(context)
+ none.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ range.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py
new file mode 100644
index 0000000..d5180b1
--- /dev/null
+++ b/psycopg/psycopg/pq/__init__.py
@@ -0,0 +1,133 @@
+"""
+psycopg libpq wrapper
+
+This package exposes the libpq functionalities as Python objects and functions.
+
+The real implementation (the binding to the C library) is
+implementation-dependant but all the implementations share the same interface.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import logging
+from typing import Callable, List, Type
+
+from . import abc
+from .misc import ConninfoOption, PGnotify, PGresAttDesc
+from .misc import error_message
+from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
+from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
+
+logger = logging.getLogger(__name__)
+
+__impl__: str
+"""The currently loaded implementation of the `!psycopg.pq` package.
+
+Possible values include ``python``, ``c``, ``binary``.
+"""
+
+__build_version__: int
+"""The libpq version the C package was built with.
+
+A number in the same format of `~psycopg.ConnectionInfo.server_version`
+representing the libpq used to build the speedup module (``c``, ``binary``) if
+available.
+
+Certain features might not be available if the built version is too old.
+"""
+
+version: Callable[[], int]
+PGconn: Type[abc.PGconn]
+PGresult: Type[abc.PGresult]
+Conninfo: Type[abc.Conninfo]
+Escaping: Type[abc.Escaping]
+PGcancel: Type[abc.PGcancel]
+
+
+def import_from_libpq() -> None:
+ """
+ Import pq objects implementation from the best libpq wrapper available.
+
+ If an implementation is requested try to import only it, otherwise
+ try to import the best implementation available.
+ """
+ # import these names into the module on success as side effect
+ global __impl__, version, __build_version__
+ global PGconn, PGresult, Conninfo, Escaping, PGcancel
+
+ impl = os.environ.get("PSYCOPG_IMPL", "").lower()
+ module = None
+ attempts: List[str] = []
+
+ def handle_error(name: str, e: Exception) -> None:
+ if not impl:
+ msg = f"couldn't import psycopg '{name}' implementation: {e}"
+ logger.debug(msg)
+ attempts.append(msg)
+ else:
+ msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
+ raise ImportError(msg) from e
+
+ # The best implementation: fast but requires the system libpq installed
+ if not impl or impl == "c":
+ try:
+ from psycopg_c import pq as module # type: ignore
+ except Exception as e:
+ handle_error("c", e)
+
+ # Second best implementation: fast and stand-alone
+ if not module and (not impl or impl == "binary"):
+ try:
+ from psycopg_binary import pq as module # type: ignore
+ except Exception as e:
+ handle_error("binary", e)
+
+ # Pure Python implementation, slow and requires the system libpq installed.
+ if not module and (not impl or impl == "python"):
+ try:
+ from . import pq_ctypes as module # type: ignore[no-redef]
+ except Exception as e:
+ handle_error("python", e)
+
+ if module:
+ __impl__ = module.__impl__
+ version = module.version
+ PGconn = module.PGconn
+ PGresult = module.PGresult
+ Conninfo = module.Conninfo
+ Escaping = module.Escaping
+ PGcancel = module.PGcancel
+ __build_version__ = module.__build_version__
+ elif impl:
+ raise ImportError(f"requested psycopg implementation '{impl}' unknown")
+ else:
+ sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
+ raise ImportError(
+ f"""\
+no pq wrapper available.
+Attempts made:
+{sattempts}"""
+ )
+
+
+import_from_libpq()
+
+__all__ = (
+ "ConnStatus",
+ "PipelineStatus",
+ "PollingStatus",
+ "TransactionStatus",
+ "ExecStatus",
+ "Ping",
+ "DiagnosticField",
+ "Format",
+ "Trace",
+ "PGconn",
+ "PGnotify",
+ "Conninfo",
+ "PGresAttDesc",
+ "error_message",
+ "ConninfoOption",
+ "version",
+)
diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py
new file mode 100644
index 0000000..f35d09f
--- /dev/null
+++ b/psycopg/psycopg/pq/_debug.py
@@ -0,0 +1,106 @@
+"""
+libpq debugging tools
+
+These functionalities are exposed here for convenience, but are not part of
+the public interface and are subject to change at any moment.
+
+Suggested usage::
+
+ import logging
+ import psycopg
+ from psycopg import pq
+ from psycopg.pq._debug import PGconnDebug
+
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+
+ assert pq.__impl__ == "python"
+ pq.PGconn = PGconnDebug
+
+ with psycopg.connect("") as conn:
+ conn.pgconn.trace(2)
+ conn.pgconn.set_trace_flags(
+ pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ ...
+
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import inspect
+import logging
+from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING
+from functools import wraps
+
+from . import PGconn
+from .misc import connection_summary
+
+if TYPE_CHECKING:
+ from . import abc
+
+Func = TypeVar("Func", bound=Callable[..., Any])
+
+logger = logging.getLogger("psycopg.debug")
+
+
+class PGconnDebug:
+ """Wrapper for a PQconn logging all its access."""
+
+ _Self = TypeVar("_Self", bound="PGconnDebug")
+ _pgconn: "abc.PGconn"
+
+ def __init__(self, pgconn: "abc.PGconn"):
+ super().__setattr__("_pgconn", pgconn)
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def __getattr__(self, attr: str) -> Any:
+ value = getattr(self._pgconn, attr)
+ if callable(value):
+ return debugging(value)
+ else:
+ logger.info("PGconn.%s -> %s", attr, value)
+ return value
+
+ def __setattr__(self, attr: str, value: Any) -> None:
+ setattr(self._pgconn, attr, value)
+ logger.info("PGconn.%s <- %s", attr, value)
+
+ @classmethod
+ def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect)(conninfo))
+
+ @classmethod
+ def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect_start)(conninfo))
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ return debugging(PGconn.ping)(conninfo)
+
+
+def debugging(f: Func) -> Func:
+ """Wrap a function in order to log its arguments and return value on call."""
+
+ @wraps(f)
+ def debugging_(*args: Any, **kwargs: Any) -> Any:
+ reprs = []
+ for arg in args:
+ reprs.append(f"{arg!r}")
+ for (k, v) in kwargs.items():
+ reprs.append(f"{k}={v!r}")
+
+ logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs))
+ rv = f(*args, **kwargs)
+ # Display the return value only if the function is declared to return
+ # something else than None.
+ ra = inspect.signature(f).return_annotation
+ if ra is not None or rv is not None:
+ logger.info(" <- %r", rv)
+ return rv
+
+ return debugging_ # type: ignore
diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py
new file mode 100644
index 0000000..e0d4018
--- /dev/null
+++ b/psycopg/psycopg/pq/_enums.py
@@ -0,0 +1,249 @@
+"""
+libpq enum definitions for psycopg
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, IntFlag, auto
+
+
+class ConnStatus(IntEnum):
+ """
+ Current status of the connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """The connection is in a working state."""
+ BAD = auto()
+ """The connection is closed."""
+
+ STARTED = auto()
+ MADE = auto()
+ AWAITING_RESPONSE = auto()
+ AUTH_OK = auto()
+ SETENV = auto()
+ SSL_STARTUP = auto()
+ NEEDED = auto()
+ CHECK_WRITABLE = auto()
+ CONSUME = auto()
+ GSS_STARTUP = auto()
+ CHECK_TARGET = auto()
+ CHECK_STANDBY = auto()
+
+
+class PollingStatus(IntEnum):
+ """
+ The status of the socket during a connection.
+
+ If ``READING`` or ``WRITING`` you may select before polling again.
+ """
+
+ __module__ = "psycopg.pq"
+
+ FAILED = 0
+ """Connection attempt failed."""
+ READING = auto()
+ """Will have to wait before reading new data."""
+ WRITING = auto()
+ """Will have to wait before writing new data."""
+ OK = auto()
+ """Connection completed."""
+
+ ACTIVE = auto()
+
+
+class ExecStatus(IntEnum):
+ """
+ The status of a command.
+ """
+
+ __module__ = "psycopg.pq"
+
+ EMPTY_QUERY = 0
+ """The string sent to the server was empty."""
+
+ COMMAND_OK = auto()
+ """Successful completion of a command returning no data."""
+
+ TUPLES_OK = auto()
+ """
+ Successful completion of a command returning data (such as a SELECT or SHOW).
+ """
+
+ COPY_OUT = auto()
+ """Copy Out (from server) data transfer started."""
+
+ COPY_IN = auto()
+ """Copy In (to server) data transfer started."""
+
+ BAD_RESPONSE = auto()
+ """The server's response was not understood."""
+
+ NONFATAL_ERROR = auto()
+ """A nonfatal error (a notice or warning) occurred."""
+
+ FATAL_ERROR = auto()
+ """A fatal error occurred."""
+
+ COPY_BOTH = auto()
+ """
+ Copy In/Out (to and from server) data transfer started.
+
+ This feature is currently used only for streaming replication, so this
+ status should not occur in ordinary applications.
+ """
+
+ SINGLE_TUPLE = auto()
+ """
+ The PGresult contains a single result tuple from the current command.
+
+ This status occurs only when single-row mode has been selected for the
+ query.
+ """
+
+ PIPELINE_SYNC = auto()
+ """
+ The PGresult represents a synchronization point in pipeline mode,
+ requested by PQpipelineSync.
+
+ This status occurs only when pipeline mode has been selected.
+ """
+
+ PIPELINE_ABORTED = auto()
+ """
+ The PGresult represents a pipeline that has received an error from the server.
+
+ PQgetResult must be called repeatedly, and each time it will return this
+ status code until the end of the current pipeline, at which point it will
+ return PGRES_PIPELINE_SYNC and normal processing can resume.
+ """
+
+
+class TransactionStatus(IntEnum):
+ """
+ The transaction status of a connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ IDLE = 0
+ """Connection ready, no transaction active."""
+
+ ACTIVE = auto()
+ """A command is in progress."""
+
+ INTRANS = auto()
+ """Connection idle in an open transaction."""
+
+ INERROR = auto()
+ """An error happened in the current transaction."""
+
+ UNKNOWN = auto()
+ """Unknown connection state, broken connection."""
+
+
+class Ping(IntEnum):
+ """Response from a ping attempt."""
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """
+ The server is running and appears to be accepting connections.
+ """
+
+ REJECT = auto()
+ """
+ The server is running but is in a state that disallows connections.
+ """
+
+ NO_RESPONSE = auto()
+ """
+ The server could not be contacted.
+ """
+
+ NO_ATTEMPT = auto()
+ """
+ No attempt was made to contact the server.
+ """
+
+
+class PipelineStatus(IntEnum):
+ """Pipeline mode status of the libpq connection."""
+
+ __module__ = "psycopg.pq"
+
+ OFF = 0
+ """
+ The libpq connection is *not* in pipeline mode.
+ """
+ ON = auto()
+ """
+ The libpq connection is in pipeline mode.
+ """
+ ABORTED = auto()
+ """
+ The libpq connection is in pipeline mode and an error occurred while
+ processing the current pipeline. The aborted flag is cleared when
+ PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
+ """
+
+
+class DiagnosticField(IntEnum):
+ """
+ Fields in an error report.
+ """
+
+ __module__ = "psycopg.pq"
+
+ # from postgres_ext.h
+ SEVERITY = ord("S")
+ SEVERITY_NONLOCALIZED = ord("V")
+ SQLSTATE = ord("C")
+ MESSAGE_PRIMARY = ord("M")
+ MESSAGE_DETAIL = ord("D")
+ MESSAGE_HINT = ord("H")
+ STATEMENT_POSITION = ord("P")
+ INTERNAL_POSITION = ord("p")
+ INTERNAL_QUERY = ord("q")
+ CONTEXT = ord("W")
+ SCHEMA_NAME = ord("s")
+ TABLE_NAME = ord("t")
+ COLUMN_NAME = ord("c")
+ DATATYPE_NAME = ord("d")
+ CONSTRAINT_NAME = ord("n")
+ SOURCE_FILE = ord("F")
+ SOURCE_LINE = ord("L")
+ SOURCE_FUNCTION = ord("R")
+
+
+class Format(IntEnum):
+ """
+ Enum representing the format of a query argument or return value.
+
+ These values are only the ones managed by the libpq. `~psycopg` may also
+ support automatically-chosen values: see `psycopg.adapt.PyFormat`.
+ """
+
+ __module__ = "psycopg.pq"
+
+ TEXT = 0
+ """Text parameter."""
+ BINARY = 1
+ """Binary parameter."""
+
+
+class Trace(IntFlag):
+ """
+ Enum to control tracing of the client/server communication.
+ """
+
+ __module__ = "psycopg.pq"
+
+ SUPPRESS_TIMESTAMPS = 1
+ """Do not include timestamps in messages."""
+
+ REGRESS_MODE = 2
+ """Redact some fields, e.g. OIDs, from messages."""
diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py
new file mode 100644
index 0000000..9ca1d12
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.py
@@ -0,0 +1,804 @@
+"""
+libpq access using ctypes
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import ctypes
+import ctypes.util
+from ctypes import Structure, CFUNCTYPE, POINTER
+from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
+from typing import List, Optional, Tuple
+
+from .misc import find_libpq_full_path
+from ..errors import NotSupportedError
+
+libname = find_libpq_full_path()
+if not libname:
+ raise ImportError("libpq library not found")
+
+pq = ctypes.cdll.LoadLibrary(libname)
+
+
+class FILE(Structure):
+ pass
+
+
+FILE_ptr = POINTER(FILE)
+
+if sys.platform == "linux":
+ libcname = ctypes.util.find_library("c")
+ assert libcname
+ libc = ctypes.cdll.LoadLibrary(libcname)
+
+ fdopen = libc.fdopen
+ fdopen.argtypes = (c_int, c_char_p)
+ fdopen.restype = FILE_ptr
+
+
+# Get the libpq version to define what functions are available.
+
+PQlibVersion = pq.PQlibVersion
+PQlibVersion.argtypes = []
+PQlibVersion.restype = c_int
+
+libpq_version = PQlibVersion()
+
+
+# libpq data types
+
+
+Oid = c_uint
+
+
+class PGconn_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresult_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PQconninfoOption_struct(Structure):
+ _fields_ = [
+ ("keyword", c_char_p),
+ ("envvar", c_char_p),
+ ("compiled", c_char_p),
+ ("val", c_char_p),
+ ("label", c_char_p),
+ ("dispchar", c_char_p),
+ ("dispsize", c_int),
+ ]
+
+
+class PGnotify_struct(Structure):
+ _fields_ = [
+ ("relname", c_char_p),
+ ("be_pid", c_int),
+ ("extra", c_char_p),
+ ]
+
+
+class PGcancel_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresAttDesc_struct(Structure):
+ _fields_ = [
+ ("name", c_char_p),
+ ("tableid", Oid),
+ ("columnid", c_int),
+ ("format", c_int),
+ ("typid", Oid),
+ ("typlen", c_int),
+ ("atttypmod", c_int),
+ ]
+
+
+PGconn_ptr = POINTER(PGconn_struct)
+PGresult_ptr = POINTER(PGresult_struct)
+PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
+PGnotify_ptr = POINTER(PGnotify_struct)
+PGcancel_ptr = POINTER(PGcancel_struct)
+PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
+
+
+# Function definitions as explained in PostgreSQL 12 documentation
+
+# 33.1. Database Connection Control Functions
+
+# PQconnectdbParams: doesn't seem useful, won't wrap for now
+
+PQconnectdb = pq.PQconnectdb
+PQconnectdb.argtypes = [c_char_p]
+PQconnectdb.restype = PGconn_ptr
+
+# PQsetdbLogin: not useful
+# PQsetdb: not useful
+
+# PQconnectStartParams: not useful
+
+PQconnectStart = pq.PQconnectStart
+PQconnectStart.argtypes = [c_char_p]
+PQconnectStart.restype = PGconn_ptr
+
+PQconnectPoll = pq.PQconnectPoll
+PQconnectPoll.argtypes = [PGconn_ptr]
+PQconnectPoll.restype = c_int
+
+PQconndefaults = pq.PQconndefaults
+PQconndefaults.argtypes = []
+PQconndefaults.restype = PQconninfoOption_ptr
+
+PQconninfoFree = pq.PQconninfoFree
+PQconninfoFree.argtypes = [PQconninfoOption_ptr]
+PQconninfoFree.restype = None
+
+PQconninfo = pq.PQconninfo
+PQconninfo.argtypes = [PGconn_ptr]
+PQconninfo.restype = PQconninfoOption_ptr
+
+PQconninfoParse = pq.PQconninfoParse
+PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
+PQconninfoParse.restype = PQconninfoOption_ptr
+
+PQfinish = pq.PQfinish
+PQfinish.argtypes = [PGconn_ptr]
+PQfinish.restype = None
+
+PQreset = pq.PQreset
+PQreset.argtypes = [PGconn_ptr]
+PQreset.restype = None
+
+PQresetStart = pq.PQresetStart
+PQresetStart.argtypes = [PGconn_ptr]
+PQresetStart.restype = c_int
+
+PQresetPoll = pq.PQresetPoll
+PQresetPoll.argtypes = [PGconn_ptr]
+PQresetPoll.restype = c_int
+
+PQping = pq.PQping
+PQping.argtypes = [c_char_p]
+PQping.restype = c_int
+
+
+# 33.2. Connection Status Functions
+
+PQdb = pq.PQdb
+PQdb.argtypes = [PGconn_ptr]
+PQdb.restype = c_char_p
+
+PQuser = pq.PQuser
+PQuser.argtypes = [PGconn_ptr]
+PQuser.restype = c_char_p
+
+PQpass = pq.PQpass
+PQpass.argtypes = [PGconn_ptr]
+PQpass.restype = c_char_p
+
+PQhost = pq.PQhost
+PQhost.argtypes = [PGconn_ptr]
+PQhost.restype = c_char_p
+
+_PQhostaddr = None
+
+if libpq_version >= 120000:
+ _PQhostaddr = pq.PQhostaddr
+ _PQhostaddr.argtypes = [PGconn_ptr]
+ _PQhostaddr.restype = c_char_p
+
+
+def PQhostaddr(pgconn: PGconn_struct) -> bytes:
+ if not _PQhostaddr:
+ raise NotSupportedError(
+ "PQhostaddr requires libpq from PostgreSQL 12,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQhostaddr(pgconn)
+
+
+PQport = pq.PQport
+PQport.argtypes = [PGconn_ptr]
+PQport.restype = c_char_p
+
+PQtty = pq.PQtty
+PQtty.argtypes = [PGconn_ptr]
+PQtty.restype = c_char_p
+
+PQoptions = pq.PQoptions
+PQoptions.argtypes = [PGconn_ptr]
+PQoptions.restype = c_char_p
+
+PQstatus = pq.PQstatus
+PQstatus.argtypes = [PGconn_ptr]
+PQstatus.restype = c_int
+
+PQtransactionStatus = pq.PQtransactionStatus
+PQtransactionStatus.argtypes = [PGconn_ptr]
+PQtransactionStatus.restype = c_int
+
+PQparameterStatus = pq.PQparameterStatus
+PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
+PQparameterStatus.restype = c_char_p
+
+PQprotocolVersion = pq.PQprotocolVersion
+PQprotocolVersion.argtypes = [PGconn_ptr]
+PQprotocolVersion.restype = c_int
+
+PQserverVersion = pq.PQserverVersion
+PQserverVersion.argtypes = [PGconn_ptr]
+PQserverVersion.restype = c_int
+
+PQerrorMessage = pq.PQerrorMessage
+PQerrorMessage.argtypes = [PGconn_ptr]
+PQerrorMessage.restype = c_char_p
+
+PQsocket = pq.PQsocket
+PQsocket.argtypes = [PGconn_ptr]
+PQsocket.restype = c_int
+
+PQbackendPID = pq.PQbackendPID
+PQbackendPID.argtypes = [PGconn_ptr]
+PQbackendPID.restype = c_int
+
+PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
+PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
+PQconnectionNeedsPassword.restype = c_int
+
+PQconnectionUsedPassword = pq.PQconnectionUsedPassword
+PQconnectionUsedPassword.argtypes = [PGconn_ptr]
+PQconnectionUsedPassword.restype = c_int
+
+PQsslInUse = pq.PQsslInUse
+PQsslInUse.argtypes = [PGconn_ptr]
+PQsslInUse.restype = c_int
+
+# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
+
+
+# 33.3. Command Execution Functions
+
+PQexec = pq.PQexec
+PQexec.argtypes = [PGconn_ptr, c_char_p]
+PQexec.restype = PGresult_ptr
+
+PQexecParams = pq.PQexecParams
+PQexecParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecParams.restype = PGresult_ptr
+
+PQprepare = pq.PQprepare
+PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQprepare.restype = PGresult_ptr
+
+PQexecPrepared = pq.PQexecPrepared
+PQexecPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecPrepared.restype = PGresult_ptr
+
+PQdescribePrepared = pq.PQdescribePrepared
+PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePrepared.restype = PGresult_ptr
+
+PQdescribePortal = pq.PQdescribePortal
+PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePortal.restype = PGresult_ptr
+
+PQresultStatus = pq.PQresultStatus
+PQresultStatus.argtypes = [PGresult_ptr]
+PQresultStatus.restype = c_int
+
+# PQresStatus: not needed, we have pretty enums
+
+PQresultErrorMessage = pq.PQresultErrorMessage
+PQresultErrorMessage.argtypes = [PGresult_ptr]
+PQresultErrorMessage.restype = c_char_p
+
+# TODO: PQresultVerboseErrorMessage
+
+PQresultErrorField = pq.PQresultErrorField
+PQresultErrorField.argtypes = [PGresult_ptr, c_int]
+PQresultErrorField.restype = c_char_p
+
+PQclear = pq.PQclear
+PQclear.argtypes = [PGresult_ptr]
+PQclear.restype = None
+
+
+# 33.3.2. Retrieving Query Result Information
+
+PQntuples = pq.PQntuples
+PQntuples.argtypes = [PGresult_ptr]
+PQntuples.restype = c_int
+
+PQnfields = pq.PQnfields
+PQnfields.argtypes = [PGresult_ptr]
+PQnfields.restype = c_int
+
+PQfname = pq.PQfname
+PQfname.argtypes = [PGresult_ptr, c_int]
+PQfname.restype = c_char_p
+
+# PQfnumber: useless and hard to use
+
+PQftable = pq.PQftable
+PQftable.argtypes = [PGresult_ptr, c_int]
+PQftable.restype = Oid
+
+PQftablecol = pq.PQftablecol
+PQftablecol.argtypes = [PGresult_ptr, c_int]
+PQftablecol.restype = c_int
+
+PQfformat = pq.PQfformat
+PQfformat.argtypes = [PGresult_ptr, c_int]
+PQfformat.restype = c_int
+
+PQftype = pq.PQftype
+PQftype.argtypes = [PGresult_ptr, c_int]
+PQftype.restype = Oid
+
+PQfmod = pq.PQfmod
+PQfmod.argtypes = [PGresult_ptr, c_int]
+PQfmod.restype = c_int
+
+PQfsize = pq.PQfsize
+PQfsize.argtypes = [PGresult_ptr, c_int]
+PQfsize.restype = c_int
+
+PQbinaryTuples = pq.PQbinaryTuples
+PQbinaryTuples.argtypes = [PGresult_ptr]
+PQbinaryTuples.restype = c_int
+
+PQgetvalue = pq.PQgetvalue
+PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
+
+PQgetisnull = pq.PQgetisnull
+PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetisnull.restype = c_int
+
+PQgetlength = pq.PQgetlength
+PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetlength.restype = c_int
+
+PQnparams = pq.PQnparams
+PQnparams.argtypes = [PGresult_ptr]
+PQnparams.restype = c_int
+
+PQparamtype = pq.PQparamtype
+PQparamtype.argtypes = [PGresult_ptr, c_int]
+PQparamtype.restype = Oid
+
+# PQprint: pretty useless
+
+# 33.3.3. Retrieving Other Result Information
+
+PQcmdStatus = pq.PQcmdStatus
+PQcmdStatus.argtypes = [PGresult_ptr]
+PQcmdStatus.restype = c_char_p
+
+PQcmdTuples = pq.PQcmdTuples
+PQcmdTuples.argtypes = [PGresult_ptr]
+PQcmdTuples.restype = c_char_p
+
+PQoidValue = pq.PQoidValue
+PQoidValue.argtypes = [PGresult_ptr]
+PQoidValue.restype = Oid
+
+
+# 33.3.4. Escaping Strings for Inclusion in SQL Commands
+
+PQescapeLiteral = pq.PQescapeLiteral
+PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeLiteral.restype = POINTER(c_char)
+
+PQescapeIdentifier = pq.PQescapeIdentifier
+PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeIdentifier.restype = POINTER(c_char)
+
+PQescapeStringConn = pq.PQescapeStringConn
+# TODO: raises "wrong type" error
+# PQescapeStringConn.argtypes = [
+# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
+# ]
+PQescapeStringConn.restype = c_size_t
+
+PQescapeString = pq.PQescapeString
+# TODO: raises "wrong type" error
+# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
+PQescapeString.restype = c_size_t
+
+PQescapeByteaConn = pq.PQescapeByteaConn
+PQescapeByteaConn.argtypes = [
+ PGconn_ptr,
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeByteaConn.restype = POINTER(c_ubyte)
+
+PQescapeBytea = pq.PQescapeBytea
+PQescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeBytea.restype = POINTER(c_ubyte)
+
+
+PQunescapeBytea = pq.PQunescapeBytea
+PQunescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ POINTER(c_size_t),
+]
+PQunescapeBytea.restype = POINTER(c_ubyte)
+
+
+# 33.4. Asynchronous Command Processing
+
+PQsendQuery = pq.PQsendQuery
+PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
+PQsendQuery.restype = c_int
+
+PQsendQueryParams = pq.PQsendQueryParams
+PQsendQueryParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryParams.restype = c_int
+
+PQsendPrepare = pq.PQsendPrepare
+PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQsendPrepare.restype = c_int
+
+PQsendQueryPrepared = pq.PQsendQueryPrepared
+PQsendQueryPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryPrepared.restype = c_int
+
+PQsendDescribePrepared = pq.PQsendDescribePrepared
+PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePrepared.restype = c_int
+
+PQsendDescribePortal = pq.PQsendDescribePortal
+PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePortal.restype = c_int
+
+PQgetResult = pq.PQgetResult
+PQgetResult.argtypes = [PGconn_ptr]
+PQgetResult.restype = PGresult_ptr
+
+PQconsumeInput = pq.PQconsumeInput
+PQconsumeInput.argtypes = [PGconn_ptr]
+PQconsumeInput.restype = c_int
+
+PQisBusy = pq.PQisBusy
+PQisBusy.argtypes = [PGconn_ptr]
+PQisBusy.restype = c_int
+
+PQsetnonblocking = pq.PQsetnonblocking
+PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
+PQsetnonblocking.restype = c_int
+
+PQisnonblocking = pq.PQisnonblocking
+PQisnonblocking.argtypes = [PGconn_ptr]
+PQisnonblocking.restype = c_int
+
+PQflush = pq.PQflush
+PQflush.argtypes = [PGconn_ptr]
+PQflush.restype = c_int
+
+
+# 33.5. Retrieving Query Results Row-by-Row
+PQsetSingleRowMode = pq.PQsetSingleRowMode
+PQsetSingleRowMode.argtypes = [PGconn_ptr]
+PQsetSingleRowMode.restype = c_int
+
+
+# 33.6. Canceling Queries in Progress
+
+PQgetCancel = pq.PQgetCancel
+PQgetCancel.argtypes = [PGconn_ptr]
+PQgetCancel.restype = PGcancel_ptr
+
+PQfreeCancel = pq.PQfreeCancel
+PQfreeCancel.argtypes = [PGcancel_ptr]
+PQfreeCancel.restype = None
+
+PQcancel = pq.PQcancel
+# TODO: raises "wrong type" error
+# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
+PQcancel.restype = c_int
+
+
+# 33.8. Asynchronous Notification
+
+PQnotifies = pq.PQnotifies
+PQnotifies.argtypes = [PGconn_ptr]
+PQnotifies.restype = PGnotify_ptr
+
+
+# 33.9. Functions Associated with the COPY Command
+
+PQputCopyData = pq.PQputCopyData
+PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
+PQputCopyData.restype = c_int
+
+PQputCopyEnd = pq.PQputCopyEnd
+PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
+PQputCopyEnd.restype = c_int
+
+PQgetCopyData = pq.PQgetCopyData
+PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
+PQgetCopyData.restype = c_int
+
+
+# 33.10. Control Functions
+
+PQtrace = pq.PQtrace
+PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
+PQtrace.restype = None
+
+_PQsetTraceFlags = None
+
+if libpq_version >= 140000:
+ _PQsetTraceFlags = pq.PQsetTraceFlags
+ _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
+ _PQsetTraceFlags.restype = None
+
+
+def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
+ if not _PQsetTraceFlags:
+ raise NotSupportedError(
+ "PQsetTraceFlags requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+
+ _PQsetTraceFlags(pgconn, flags)
+
+
+PQuntrace = pq.PQuntrace
+PQuntrace.argtypes = [PGconn_ptr]
+PQuntrace.restype = None
+
+# 33.11. Miscellaneous Functions
+
+PQfreemem = pq.PQfreemem
+PQfreemem.argtypes = [c_void_p]
+PQfreemem.restype = None
+
+if libpq_version >= 100000:
+ _PQencryptPasswordConn = pq.PQencryptPasswordConn
+ _PQencryptPasswordConn.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_char_p,
+ c_char_p,
+ ]
+ _PQencryptPasswordConn.restype = POINTER(c_char)
+
+
+def PQencryptPasswordConn(
+ pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
+) -> Optional[bytes]:
+ if not _PQencryptPasswordConn:
+ raise NotSupportedError(
+ "PQencryptPasswordConn requires libpq from PostgreSQL 10,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
+
+
+PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
+PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
+PQmakeEmptyPGresult.restype = PGresult_ptr
+
+PQsetResultAttrs = pq.PQsetResultAttrs
+PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
+PQsetResultAttrs.restype = c_int
+
+
+# 33.12. Notice Processing
+
+PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
+
+PQsetNoticeReceiver = pq.PQsetNoticeReceiver
+PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
+PQsetNoticeReceiver.restype = PQnoticeReceiver
+
+# 34.5 Pipeline Mode
+
+_PQpipelineStatus = None
+_PQenterPipelineMode = None
+_PQexitPipelineMode = None
+_PQpipelineSync = None
+_PQsendFlushRequest = None
+
+if libpq_version >= 140000:
+ _PQpipelineStatus = pq.PQpipelineStatus
+ _PQpipelineStatus.argtypes = [PGconn_ptr]
+ _PQpipelineStatus.restype = c_int
+
+ _PQenterPipelineMode = pq.PQenterPipelineMode
+ _PQenterPipelineMode.argtypes = [PGconn_ptr]
+ _PQenterPipelineMode.restype = c_int
+
+ _PQexitPipelineMode = pq.PQexitPipelineMode
+ _PQexitPipelineMode.argtypes = [PGconn_ptr]
+ _PQexitPipelineMode.restype = c_int
+
+ _PQpipelineSync = pq.PQpipelineSync
+ _PQpipelineSync.argtypes = [PGconn_ptr]
+ _PQpipelineSync.restype = c_int
+
+ _PQsendFlushRequest = pq.PQsendFlushRequest
+ _PQsendFlushRequest.argtypes = [PGconn_ptr]
+ _PQsendFlushRequest.restype = c_int
+
+
+def PQpipelineStatus(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineStatus:
+ raise NotSupportedError(
+ "PQpipelineStatus requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineStatus(pgconn)
+
+
+def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQenterPipelineMode:
+ raise NotSupportedError(
+ "PQenterPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQenterPipelineMode(pgconn)
+
+
+def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQexitPipelineMode:
+ raise NotSupportedError(
+ "PQexitPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQexitPipelineMode(pgconn)
+
+
+def PQpipelineSync(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineSync:
+ raise NotSupportedError(
+ "PQpipelineSync requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineSync(pgconn)
+
+
+def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
+ if not _PQsendFlushRequest:
+ raise NotSupportedError(
+ "PQsendFlushRequest requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQsendFlushRequest(pgconn)
+
+
+# 33.18. SSL Support
+
+PQinitOpenSSL = pq.PQinitOpenSSL
+PQinitOpenSSL.argtypes = [c_int, c_int]
+PQinitOpenSSL.restype = None
+
+
+def generate_stub() -> None:
+ import re
+ from ctypes import _CFuncPtr # type: ignore
+
+ def type2str(fname, narg, t):
+ if t is None:
+ return "None"
+ elif t is c_void_p:
+ return "Any"
+ elif t is c_int or t is c_uint or t is c_size_t:
+ return "int"
+ elif t is c_char_p or t.__name__ == "LP_c_char":
+ if narg is not None:
+ return "bytes"
+ else:
+ return "Optional[bytes]"
+
+ elif t.__name__ in (
+ "LP_PGconn_struct",
+ "LP_PGresult_struct",
+ "LP_PGcancel_struct",
+ ):
+ if narg is not None:
+ return f"Optional[{t.__name__[3:]}]"
+ else:
+ return t.__name__[3:]
+
+ elif t.__name__ in ("LP_PQconninfoOption_struct",):
+ return f"Sequence[{t.__name__[3:]}]"
+
+ elif t.__name__ in (
+ "LP_c_ubyte",
+ "LP_c_char_p",
+ "LP_c_int",
+ "LP_c_uint",
+ "LP_c_ulong",
+ "LP_FILE",
+ ):
+ return f"_Pointer[{t.__name__[3:]}]"
+
+ else:
+ assert False, f"can't deal with {t} in {fname}"
+
+ fn = __file__ + "i"
+ with open(fn) as f:
+ lines = f.read().splitlines()
+
+ istart, iend = (
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
+ )
+
+ known = {
+ line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
+ }
+
+ signatures = []
+
+ for name, obj in globals().items():
+ if name in known:
+ continue
+ if not isinstance(obj, _CFuncPtr):
+ continue
+
+ params = []
+ for i, t in enumerate(obj.argtypes):
+ params.append(f"arg{i + 1}: {type2str(name, i, t)}")
+
+ resname = type2str(name, None, obj.restype)
+
+ signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
+
+ lines[istart + 1 : iend] = signatures
+
+ with open(fn, "w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+if __name__ == "__main__":
+ generate_stub()
diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi
new file mode 100644
index 0000000..5d2ee3f
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.pyi
@@ -0,0 +1,216 @@
+"""
+types stub for ctypes functions
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Optional, Sequence
+from ctypes import Array, pointer, _Pointer
+from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
+
+class FILE: ...
+
+def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var]
+
+Oid = c_uint
+
+class PGconn_struct: ...
+class PGresult_struct: ...
+class PGcancel_struct: ...
+
+class PQconninfoOption_struct:
+ keyword: bytes
+ envvar: bytes
+ compiled: bytes
+ val: bytes
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+class PGnotify_struct:
+ be_pid: int
+ relname: bytes
+ extra: bytes
+
+class PGresAttDesc_struct:
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQexecPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> PGresult_struct: ...
+def PQprepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> PGresult_struct: ...
+def PQgetvalue(
+ arg1: Optional[PGresult_struct], arg2: int, arg3: int
+) -> _Pointer[c_char]: ...
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQescapeStringConn(
+ arg1: Optional[PGconn_struct],
+ arg2: c_char_p,
+ arg3: bytes,
+ arg4: int,
+ arg5: _Pointer[c_int],
+) -> int: ...
+def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ...
+def PQsendPrepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> int: ...
+def PQsendQueryPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> int: ...
+def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
+def PQsetNoticeReceiver(
+ arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
+) -> Callable[[Any], PGresult_struct]: ...
+
+# TODO: Ignoring type as getting an error on mypy/ctypes:
+# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be
+# a subtype of "ctypes._CData"
+def PQnotifies(
+ arg1: Optional[PGconn_struct],
+) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore
+def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
+
+# Arg 2 is a _Pointer, reported as _CArgObject by mypy
+def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
+def PQsetResultAttrs(
+ arg1: Optional[PGresult_struct],
+ arg2: int,
+ arg3: Array[PGresAttDesc_struct], # type: ignore
+) -> int: ...
+def PQtrace(
+ arg1: Optional[PGconn_struct],
+ arg2: _Pointer[FILE], # type: ignore[type-var]
+) -> None: ...
+def PQencryptPasswordConn(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: Optional[bytes],
+) -> bytes: ...
+def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ...
+
+# fmt: off
+# autogenerated: start
+def PQlibVersion() -> int: ...
+def PQconnectdb(arg1: bytes) -> PGconn_struct: ...
+def PQconnectStart(arg1: bytes) -> PGconn_struct: ...
+def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ...
+def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ...
+def PQfinish(arg1: Optional[PGconn_struct]) -> None: ...
+def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
+def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
+def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQping(arg1: bytes) -> int: ...
+def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ...
+def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
+def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ...
+def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ...
+def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
+def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
+def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
+def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
+def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ...
+def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
+def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ...
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
+def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
+def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
+def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
+def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
+def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
+def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ...
+def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ...
+def PQfreemem(arg1: Any) -> None: ...
+def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ...
+def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
+def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ...
+def PQinitOpenSSL(arg1: int, arg2: int) -> None: ...
+# autogenerated: end
+# fmt: on
+
+# vim: set syntax=python:
diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py
new file mode 100644
index 0000000..9c45f64
--- /dev/null
+++ b/psycopg/psycopg/pq/abc.py
@@ -0,0 +1,385 @@
+"""
+Protocol objects to represent objects exposed by different pq implementations.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from ._enums import Format, Trace
+from .._compat import Protocol
+
+if TYPE_CHECKING:
+ from .misc import PGnotify, ConninfoOption, PGresAttDesc
+
+# An object implementing the buffer protocol (ish)
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+
+class PGconn(Protocol):
+
+ notice_handler: Optional[Callable[["PGresult"], None]]
+ notify_handler: Optional[Callable[["PGnotify"], None]]
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ def connect_poll(self) -> int:
+ ...
+
+ def finish(self) -> None:
+ ...
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ ...
+
+ def reset(self) -> None:
+ ...
+
+ def reset_start(self) -> None:
+ ...
+
+ def reset_poll(self) -> int:
+ ...
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ ...
+
+ @property
+ def db(self) -> bytes:
+ ...
+
+ @property
+ def user(self) -> bytes:
+ ...
+
+ @property
+ def password(self) -> bytes:
+ ...
+
+ @property
+ def host(self) -> bytes:
+ ...
+
+ @property
+ def hostaddr(self) -> bytes:
+ ...
+
+ @property
+ def port(self) -> bytes:
+ ...
+
+ @property
+ def tty(self) -> bytes:
+ ...
+
+ @property
+ def options(self) -> bytes:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def transaction_status(self) -> int:
+ ...
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ @property
+ def server_version(self) -> int:
+ ...
+
+ @property
+ def socket(self) -> int:
+ ...
+
+ @property
+ def backend_pid(self) -> int:
+ ...
+
+ @property
+ def needs_password(self) -> bool:
+ ...
+
+ @property
+ def used_password(self) -> bool:
+ ...
+
+ @property
+ def ssl_in_use(self) -> bool:
+ ...
+
+ def exec_(self, command: bytes) -> "PGresult":
+ ...
+
+ def send_query(self, command: bytes) -> None:
+ ...
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ ...
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ ...
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ ...
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Buffer]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ ...
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ ...
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_portal(self, name: bytes) -> None:
+ ...
+
+ def get_result(self) -> Optional["PGresult"]:
+ ...
+
+ def consume_input(self) -> None:
+ ...
+
+ def is_busy(self) -> int:
+ ...
+
+ @property
+ def nonblocking(self) -> int:
+ ...
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ ...
+
+ def flush(self) -> int:
+ ...
+
+ def set_single_row_mode(self) -> None:
+ ...
+
+ def get_cancel(self) -> "PGcancel":
+ ...
+
+ def notifies(self) -> Optional["PGnotify"]:
+ ...
+
+ def put_copy_data(self, buffer: Buffer) -> int:
+ ...
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ ...
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ ...
+
+ def trace(self, fileno: int) -> None:
+ ...
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ ...
+
+ def untrace(self) -> None:
+ ...
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ ...
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ ...
+
+ @property
+ def pipeline_status(self) -> int:
+ ...
+
+ def enter_pipeline_mode(self) -> None:
+ ...
+
+ def exit_pipeline_mode(self) -> None:
+ ...
+
+ def pipeline_sync(self) -> None:
+ ...
+
+ def send_flush_request(self) -> None:
+ ...
+
+
+class PGresult(Protocol):
+ def clear(self) -> None:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def ntuples(self) -> int:
+ ...
+
+ @property
+ def nfields(self) -> int:
+ ...
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ ...
+
+ def ftable(self, column_number: int) -> int:
+ ...
+
+ def ftablecol(self, column_number: int) -> int:
+ ...
+
+ def fformat(self, column_number: int) -> int:
+ ...
+
+ def ftype(self, column_number: int) -> int:
+ ...
+
+ def fmod(self, column_number: int) -> int:
+ ...
+
+ def fsize(self, column_number: int) -> int:
+ ...
+
+ @property
+ def binary_tuples(self) -> int:
+ ...
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def nparams(self) -> int:
+ ...
+
+ def param_type(self, param_number: int) -> int:
+ ...
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ ...
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ ...
+
+ @property
+ def oid_value(self) -> int:
+ ...
+
+ def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
+ ...
+
+
+class PGcancel(Protocol):
+ def free(self) -> None:
+ ...
+
+ def cancel(self) -> None:
+ ...
+
+
+class Conninfo(Protocol):
+ @classmethod
+ def get_defaults(cls) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
+ ...
+
+
+class Escaping(Protocol):
+ def __init__(self, conn: Optional[PGconn] = None):
+ ...
+
+ def escape_literal(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_identifier(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_string(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_bytea(self, data: Buffer) -> bytes:
+ ...
+
+ def unescape_bytea(self, data: Buffer) -> bytes:
+ ...
diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py
new file mode 100644
index 0000000..3a43133
--- /dev/null
+++ b/psycopg/psycopg/pq/misc.py
@@ -0,0 +1,146 @@
+"""
+Various functionalities to make easier to work with the libpq.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import sys
+import logging
+import ctypes.util
+from typing import cast, NamedTuple, Optional, Union
+
+from .abc import PGconn, PGresult
+from ._enums import ConnStatus, TransactionStatus, PipelineStatus
+from .._compat import cache
+from .._encodings import pgconn_encoding
+
+logger = logging.getLogger("psycopg.pq")
+
+OK = ConnStatus.OK
+
+
+class PGnotify(NamedTuple):
+ relname: bytes
+ be_pid: int
+ extra: bytes
+
+
+class ConninfoOption(NamedTuple):
+ keyword: bytes
+ envvar: Optional[bytes]
+ compiled: Optional[bytes]
+ val: Optional[bytes]
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+
+class PGresAttDesc(NamedTuple):
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+
+@cache
+def find_libpq_full_path() -> Optional[str]:
+ if sys.platform == "win32":
+ libname = ctypes.util.find_library("libpq.dll")
+
+ elif sys.platform == "darwin":
+ libname = ctypes.util.find_library("libpq.dylib")
+ # (hopefully) temporary hack: libpq not in a standard place
+ # https://github.com/orgs/Homebrew/discussions/3595
+ # If pg_config is available and agrees, let's use its indications.
+ if not libname:
+ try:
+ import subprocess as sp
+
+ libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
+ libname = os.path.join(libdir, "libpq.dylib")
+ if not os.path.exists(libname):
+ libname = None
+ except Exception as ex:
+ logger.debug("couldn't use pg_config to find libpq: %s", ex)
+
+ else:
+ libname = ctypes.util.find_library("pq")
+
+ return libname
+
+
+def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
+ """
+ Return an error message from a `PGconn` or `PGresult`.
+
+ The return value is a `!str` (unlike pq data which is usually `!bytes`):
+ use the connection encoding if available, otherwise the `!encoding`
+ parameter as a fallback for decoding. Don't raise exceptions on decoding
+ errors.
+
+ """
+ bmsg: bytes
+
+ if hasattr(obj, "error_field"):
+ # obj is a PGresult
+ obj = cast(PGresult, obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ elif hasattr(obj, "error_message"):
+ # obj is a PGconn
+ if obj.status == OK:
+ encoding = pgconn_encoding(obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ else:
+ raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
+
+ if bmsg:
+ msg = bmsg.decode(encoding, "replace")
+ else:
+ msg = "no details available"
+
+ return msg
+
+
+def connection_summary(pgconn: PGconn) -> str:
+ """
+ Return summary information on a connection.
+
+ Useful for __repr__
+ """
+ parts = []
+ if pgconn.status == OK:
+ # Put together the [STATUS]
+ status = TransactionStatus(pgconn.transaction_status).name
+ if pgconn.pipeline_status:
+ status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"
+
+ # Put together the (CONNECTION)
+ if not pgconn.host.startswith(b"/"):
+ parts.append(("host", pgconn.host.decode()))
+ if pgconn.port != b"5432":
+ parts.append(("port", pgconn.port.decode()))
+ if pgconn.user != pgconn.db:
+ parts.append(("user", pgconn.user.decode()))
+ parts.append(("database", pgconn.db.decode()))
+
+ else:
+ status = ConnStatus(pgconn.status).name
+
+ sparts = " ".join("%s=%s" % part for part in parts)
+ if sparts:
+ sparts = f" ({sparts})"
+ return f"[{status}]{sparts}"
diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py
new file mode 100644
index 0000000..8b87c19
--- /dev/null
+++ b/psycopg/psycopg/pq/pq_ctypes.py
@@ -0,0 +1,1086 @@
+"""
+libpq Python wrapper using ctypes bindings.
+
+Clients shouldn't use this module directly, unless for testing: they should use
+the `pq` module instead, which is in charge of choosing the best
+implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import logging
+from os import getpid
+from weakref import ref
+
+from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref
+from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import cast as t_cast, TYPE_CHECKING
+
+from .. import errors as e
+from . import _pq_ctypes as impl
+from .misc import PGnotify, ConninfoOption, PGresAttDesc
+from .misc import error_message, connection_summary
+from ._enums import Format, ExecStatus, Trace
+
+# Imported locally to call them from __del__ methods
+from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus
+
+if TYPE_CHECKING:
+ from . import abc
+
+__impl__ = "python"
+
+logger = logging.getLogger("psycopg")
+
+
+def version() -> int:
+ """Return the version number of the libpq currently loaded.
+
+ The number is in the same format of `~psycopg.ConnectionInfo.server_version`.
+
+ Certain features might not be available if the libpq library used is too old.
+ """
+ return impl.PQlibVersion()
+
+
+@impl.PQnoticeReceiver # type: ignore
+def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
+ pgconn = cast(arg, POINTER(py_object)).contents.value()
+ if not (pgconn and pgconn.notice_handler):
+ return
+
+ res = PGresult(result_ptr)
+ try:
+ pgconn.notice_handler(res)
+ except Exception as exc:
+ logger.exception("error in notice receiver: %s", exc)
+ finally:
+ res._pgresult_ptr = None # avoid destroying the pgresult_ptr
+
+
+class PGconn:
+ """
+ Python representation of a libpq connection.
+ """
+
+ __slots__ = (
+ "_pgconn_ptr",
+ "notice_handler",
+ "notify_handler",
+ "_self_ptr",
+ "_procpid",
+ "__weakref__",
+ )
+
+ def __init__(self, pgconn_ptr: impl.PGconn_struct):
+ self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
+ self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None
+ self.notify_handler: Optional[Callable[[PGnotify], None]] = None
+
+ # Keep alive for the lifetime of PGconn
+ self._self_ptr = py_object(ref(self))
+ impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr))
+
+ self._procpid = getpid()
+
+ def __del__(self) -> None:
+ # Close the connection only if it was created in this process,
+ # not if this object is being GC'd after fork.
+ if getpid() == self._procpid:
+ self.finish()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectdb(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectStart(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ def connect_poll(self) -> int:
+ return self._call_int(impl.PQconnectPoll)
+
+ def finish(self) -> None:
+ self._pgconn_ptr, p = None, self._pgconn_ptr
+ if p:
+ PQfinish(p)
+
+ @property
+ def pgconn_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGconn` structure, as integer.
+
+ `!None` if the connection is closed.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgconn_ptr is None:
+ return None
+
+ return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ self._ensure_pgconn()
+ opts = impl.PQconninfo(self._pgconn_ptr)
+ if not opts:
+ raise MemoryError("couldn't allocate connection info")
+ try:
+ return Conninfo._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ def reset(self) -> None:
+ self._ensure_pgconn()
+ impl.PQreset(self._pgconn_ptr)
+
+ def reset_start(self) -> None:
+ if not impl.PQresetStart(self._pgconn_ptr):
+ raise e.OperationalError("couldn't reset connection")
+
+ def reset_poll(self) -> int:
+ return self._call_int(impl.PQresetPoll)
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ return impl.PQping(conninfo)
+
+ @property
+ def db(self) -> bytes:
+ return self._call_bytes(impl.PQdb)
+
+ @property
+ def user(self) -> bytes:
+ return self._call_bytes(impl.PQuser)
+
+ @property
+ def password(self) -> bytes:
+ return self._call_bytes(impl.PQpass)
+
+ @property
+ def host(self) -> bytes:
+ return self._call_bytes(impl.PQhost)
+
+ @property
+ def hostaddr(self) -> bytes:
+ return self._call_bytes(impl.PQhostaddr)
+
+ @property
+ def port(self) -> bytes:
+ return self._call_bytes(impl.PQport)
+
+ @property
+ def tty(self) -> bytes:
+ return self._call_bytes(impl.PQtty)
+
+ @property
+ def options(self) -> bytes:
+ return self._call_bytes(impl.PQoptions)
+
+ @property
+ def status(self) -> int:
+ return PQstatus(self._pgconn_ptr)
+
+ @property
+ def transaction_status(self) -> int:
+ return impl.PQtransactionStatus(self._pgconn_ptr)
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ self._ensure_pgconn()
+ return impl.PQparameterStatus(self._pgconn_ptr, name)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQerrorMessage(self._pgconn_ptr)
+
+ @property
+ def protocol_version(self) -> int:
+ return self._call_int(impl.PQprotocolVersion)
+
+ @property
+ def server_version(self) -> int:
+ return self._call_int(impl.PQserverVersion)
+
+ @property
+ def socket(self) -> int:
+ rv = self._call_int(impl.PQsocket)
+ if rv == -1:
+ raise e.OperationalError("the connection is lost")
+ return rv
+
+ @property
+ def backend_pid(self) -> int:
+ return self._call_int(impl.PQbackendPID)
+
+ @property
+ def needs_password(self) -> bool:
+ """True if the connection authentication method required a password,
+ but none was available.
+
+ See :pq:`PQconnectionNeedsPassword` for details.
+ """
+ return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr))
+
+ @property
+ def used_password(self) -> bool:
+ """True if the connection authentication method used a password.
+
+ See :pq:`PQconnectionUsedPassword` for details.
+ """
+ return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr))
+
+ @property
+ def ssl_in_use(self) -> bool:
+ return self._call_bool(impl.PQsslInUse)
+
+ def exec_(self, command: bytes) -> "PGresult":
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQexec(self._pgconn_ptr, command)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query(self, command: bytes) -> None:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendQuery(self._pgconn_ptr, command):
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ rv = impl.PQexecParams(*args)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ if not impl.PQsendQueryParams(*args):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ # repurpose this function with a cheeky replacement of query with name,
+ # drop the param_types from the result
+ args = self._query_params_args(
+ name, param_values, None, param_formats, result_format
+ )
+ args = args[:3] + args[4:]
+
+ self._ensure_pgconn()
+ if not impl.PQsendQueryPrepared(*args):
+ raise e.OperationalError(
+ f"sending prepared query failed: {error_message(self)}"
+ )
+
+ def _query_params_args(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> Any:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alenghts = None
+
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ atypes = None
+ else:
+ if len(param_types) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_types))
+ )
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_formats"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ return (
+ self._pgconn_ptr,
+ command,
+ nparams,
+ atypes,
+ aparams,
+ alenghts,
+ aformats,
+ result_format,
+ )
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ if not isinstance(command, bytes):
+ raise TypeError(f"'command' must be bytes, got {type(command)} instead")
+
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence["abc.Buffer"]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alenghts: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alenghts = None
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ self._ensure_pgconn()
+ rv = impl.PQexecPrepared(
+ self._pgconn_ptr,
+ name,
+ nparams,
+ aparams,
+ alenghts,
+ aformats,
+ result_format,
+ )
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePrepared(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePrepared(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePortal(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_portal(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePortal(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe portal failed: {error_message(self)}"
+ )
+
+ def get_result(self) -> Optional["PGresult"]:
+ rv = impl.PQgetResult(self._pgconn_ptr)
+ return PGresult(rv) if rv else None
+
+ def consume_input(self) -> None:
+ if 1 != impl.PQconsumeInput(self._pgconn_ptr):
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
+
+ def is_busy(self) -> int:
+ return impl.PQisBusy(self._pgconn_ptr)
+
+ @property
+ def nonblocking(self) -> int:
+ return impl.PQisnonblocking(self._pgconn_ptr)
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg):
+ raise e.OperationalError(
+ f"setting nonblocking failed: {error_message(self)}"
+ )
+
+ def flush(self) -> int:
+ # PQflush segfaults if it receives a NULL connection
+ if not self._pgconn_ptr:
+ raise e.OperationalError("flushing failed: the connection is closed")
+ rv: int = impl.PQflush(self._pgconn_ptr)
+ if rv < 0:
+ raise e.OperationalError(f"flushing failed: {error_message(self)}")
+ return rv
+
+ def set_single_row_mode(self) -> None:
+ if not impl.PQsetSingleRowMode(self._pgconn_ptr):
+ raise e.OperationalError("setting single row mode failed")
+
+ def get_cancel(self) -> "PGcancel":
+ """
+ Create an object with the information needed to cancel a command.
+
+ See :pq:`PQgetCancel` for details.
+ """
+ rv = impl.PQgetCancel(self._pgconn_ptr)
+ if not rv:
+ raise e.OperationalError("couldn't create cancel object")
+ return PGcancel(rv)
+
+ def notifies(self) -> Optional[PGnotify]:
+ ptr = impl.PQnotifies(self._pgconn_ptr)
+ if ptr:
+ c = ptr.contents
+ return PGnotify(c.relname, c.be_pid, c.extra)
+ impl.PQfreemem(ptr)
+ else:
+ return None
+
+ def put_copy_data(self, buffer: "abc.Buffer") -> int:
+ if not isinstance(buffer, bytes):
+ buffer = bytes(buffer)
+ rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
+ if rv < 0:
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
+ return rv
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ buffer_ptr = c_char_p()
+ nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_)
+ if nbytes == -2:
+ raise e.OperationalError(
+ f"receiving copy data failed: {error_message(self)}"
+ )
+ if buffer_ptr:
+ # TODO: do it without copy
+ data = string_at(buffer_ptr, nbytes)
+ impl.PQfreemem(buffer_ptr)
+ return nbytes, memoryview(data)
+ else:
+ return nbytes, memoryview(b"")
+
+ def trace(self, fileno: int) -> None:
+ """
+ Enable tracing of the client/server communication to a file stream.
+
+ See :pq:`PQtrace` for details.
+ """
+ if sys.platform != "linux":
+ raise e.NotSupportedError("currently only supported on Linux")
+ stream = impl.fdopen(fileno, b"w")
+ impl.PQtrace(self._pgconn_ptr, stream)
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ """
+ Configure tracing behavior of client/server communication.
+
+ :param flags: operating mode of tracing.
+
+ See :pq:`PQsetTraceFlags` for details.
+ """
+ impl.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+ def untrace(self) -> None:
+ """
+ Disable tracing, previously enabled through `trace()`.
+
+ See :pq:`PQuntrace` for details.
+ """
+ impl.PQuntrace(self._pgconn_ptr)
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ """
+ Return the encrypted form of a PostgreSQL password.
+
+ See :pq:`PQencryptPasswordConn` for details.
+ """
+ out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm)
+ if not out:
+ raise e.OperationalError(
+ f"password encryption failed: {error_message(self)}"
+ )
+
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)
+ if not rv:
+ raise MemoryError("couldn't allocate empty PGresult")
+ return PGresult(rv)
+
+ @property
+ def pipeline_status(self) -> int:
+ if version() < 140000:
+ return 0
+ return impl.PQpipelineStatus(self._pgconn_ptr)
+
+ def enter_pipeline_mode(self) -> None:
+ """Enter pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to enter the pipeline
+ mode.
+ """
+ if impl.PQenterPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError("failed to enter pipeline mode")
+
+ def exit_pipeline_mode(self) -> None:
+ """Exit pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to exit the pipeline
+ mode.
+ """
+ if impl.PQexitPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError(error_message(self))
+
+ def pipeline_sync(self) -> None:
+ """Mark a synchronization point in a pipeline.
+
+ :raises ~e.OperationalError: if the connection is not in pipeline mode
+ or if sync failed.
+ """
+ rv = impl.PQpipelineSync(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError("connection not in pipeline mode")
+ if rv != 1:
+ raise e.OperationalError("failed to sync pipeline")
+
+ def send_flush_request(self) -> None:
+ """Sends a request for the server to flush its output buffer.
+
+ :raises ~e.OperationalError: if the flush request failed.
+ """
+ if impl.PQsendFlushRequest(self._pgconn_ptr) == 0:
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
+
+ def _call_bytes(
+ self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
+ ) -> bytes:
+ """
+ Call one of the pgconn libpq functions returning a bytes pointer.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ rv = func(self._pgconn_ptr)
+ assert rv is not None
+ return rv
+
+ def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int:
+ """
+ Call one of the pgconn libpq functions returning an int.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return func(self._pgconn_ptr)
+
+ def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool:
+ """
+ Call one of the pgconn libpq functions returning a logical value.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return bool(func(self._pgconn_ptr))
+
+ def _ensure_pgconn(self) -> None:
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+
+
+class PGresult:
+ """
+ Python representation of a libpq result.
+ """
+
+ __slots__ = ("_pgresult_ptr",)
+
+ def __init__(self, pgresult_ptr: impl.PGresult_struct):
+ self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr
+
+ def __del__(self) -> None:
+ self.clear()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ status = ExecStatus(self.status)
+ return f"<{cls} [{status.name}] at 0x{id(self):x}>"
+
+ def clear(self) -> None:
+ self._pgresult_ptr, p = None, self._pgresult_ptr
+ if p:
+ PQclear(p)
+
+ @property
+ def pgresult_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGresult` structure, as integer.
+
+ `!None` if the result was cleared.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgresult_ptr is None:
+ return None
+
+ return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def status(self) -> int:
+ return impl.PQresultStatus(self._pgresult_ptr)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQresultErrorMessage(self._pgresult_ptr)
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ return impl.PQresultErrorField(self._pgresult_ptr, fieldcode)
+
+ @property
+ def ntuples(self) -> int:
+ return impl.PQntuples(self._pgresult_ptr)
+
+ @property
+ def nfields(self) -> int:
+ return impl.PQnfields(self._pgresult_ptr)
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ return impl.PQfname(self._pgresult_ptr, column_number)
+
+ def ftable(self, column_number: int) -> int:
+ return impl.PQftable(self._pgresult_ptr, column_number)
+
+ def ftablecol(self, column_number: int) -> int:
+ return impl.PQftablecol(self._pgresult_ptr, column_number)
+
+ def fformat(self, column_number: int) -> int:
+ return impl.PQfformat(self._pgresult_ptr, column_number)
+
+ def ftype(self, column_number: int) -> int:
+ return impl.PQftype(self._pgresult_ptr, column_number)
+
+ def fmod(self, column_number: int) -> int:
+ return impl.PQfmod(self._pgresult_ptr, column_number)
+
+ def fsize(self, column_number: int) -> int:
+ return impl.PQfsize(self._pgresult_ptr, column_number)
+
+ @property
+ def binary_tuples(self) -> int:
+ return impl.PQbinaryTuples(self._pgresult_ptr)
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number)
+ if length:
+ v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
+ return string_at(v, length)
+ else:
+ if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
+ return None
+ else:
+ return b""
+
+ @property
+ def nparams(self) -> int:
+ return impl.PQnparams(self._pgresult_ptr)
+
+ def param_type(self, param_number: int) -> int:
+ return impl.PQparamtype(self._pgresult_ptr, param_number)
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ return impl.PQcmdStatus(self._pgresult_ptr)
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ rv = impl.PQcmdTuples(self._pgresult_ptr)
+ return int(rv) if rv else None
+
+ @property
+ def oid_value(self) -> int:
+ return impl.PQoidValue(self._pgresult_ptr)
+
+ def set_attributes(self, descriptions: List[PGresAttDesc]) -> None:
+ structs = [
+ impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore
+ ]
+ array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore
+ rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
+ if rv == 0:
+ raise e.OperationalError("PQsetResultAttrs failed")
+
+
+class PGcancel:
+ """
+ Token to cancel the current operation on a connection.
+
+ Created by `PGconn.get_cancel()`.
+ """
+
+ __slots__ = ("pgcancel_ptr",)
+
+ def __init__(self, pgcancel_ptr: impl.PGcancel_struct):
+ self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr
+
+ def __del__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ """
+ Free the data structure created by :pq:`PQgetCancel()`.
+
+ Automatically invoked by `!__del__()`.
+
+ See :pq:`PQfreeCancel()` for details.
+ """
+ self.pgcancel_ptr, p = None, self.pgcancel_ptr
+ if p:
+ PQfreeCancel(p)
+
+ def cancel(self) -> None:
+ """Requests that the server abandon processing of the current command.
+
+ See :pq:`PQcancel()` for details.
+ """
+ buf = create_string_buffer(256)
+ res = impl.PQcancel(
+ self.pgcancel_ptr,
+ byref(buf), # type: ignore[arg-type]
+ len(buf),
+ )
+ if not res:
+ raise e.OperationalError(
+ f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
+ )
+
+
+class Conninfo:
+ """
+ Utility object to manipulate connection strings.
+ """
+
+ @classmethod
+ def get_defaults(cls) -> List[ConninfoOption]:
+ opts = impl.PQconndefaults()
+ if not opts:
+ raise MemoryError("couldn't allocate connection defaults")
+ try:
+ return cls._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List[ConninfoOption]:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ errmsg = c_char_p()
+ rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type]
+ if not rv:
+ if not errmsg:
+ raise MemoryError("couldn't allocate on conninfo parse")
+ else:
+ exc = e.OperationalError(
+ (errmsg.value or b"").decode("utf8", "replace")
+ )
+ impl.PQfreemem(errmsg)
+ raise exc
+
+ try:
+ return cls._options_from_array(rv)
+ finally:
+ impl.PQconninfoFree(rv)
+
+ @classmethod
+ def _options_from_array(
+ cls, opts: Sequence[impl.PQconninfoOption_struct]
+ ) -> List[ConninfoOption]:
+ rv = []
+ skws = "keyword envvar compiled val label dispchar".split()
+ for opt in opts:
+ if not opt.keyword:
+ break
+ d = {kw: getattr(opt, kw) for kw in skws}
+ d["dispsize"] = opt.dispsize
+ rv.append(ConninfoOption(**d))
+
+ return rv
+
+
+class Escaping:
+ """
+ Utility object to escape strings for SQL interpolation.
+ """
+
+ def __init__(self, conn: Optional[PGconn] = None):
+ self.conn = conn
+
+ def escape_literal(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_literal failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+ # TODO: might be done without copy (however C does that)
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_literal failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_identifier(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_identifier failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_identifier failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_string(self, data: "abc.Buffer") -> bytes:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ if self.conn:
+ self.conn._ensure_pgconn()
+ error = c_int()
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeStringConn(
+ self.conn._pgconn_ptr,
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ byref(error), # type: ignore[arg-type]
+ )
+
+ if error:
+ raise e.OperationalError(
+ f"escape_string failed: {error_message(self.conn)} bytes"
+ )
+
+ else:
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeString(
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ )
+
+ return out.value
+
+ def escape_bytea(self, data: "abc.Buffer") -> bytes:
+ len_out = c_size_t()
+ # TODO: might be able to do without a copy but it's a mess.
+ # the C library does it better anyway, so maybe not worth optimising
+ # https://mail.python.org/pipermail/python-dev/2012-September/121780.html
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ if self.conn:
+ self.conn._ensure_pgconn()
+ out = impl.PQescapeByteaConn(
+ self.conn._pgconn_ptr,
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ else:
+ out = impl.PQescapeBytea(
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value - 1) # out includes final 0
+ impl.PQfreemem(out)
+ return rv
+
+ def unescape_bytea(self, data: "abc.Buffer") -> bytes:
+ # not needed, but let's keep it symmetric with the escaping:
+ # if a connection is passed in, it must be valid.
+ if self.conn:
+ self.conn._ensure_pgconn()
+
+ len_out = c_size_t()
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQunescapeBytea(
+ data,
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value)
+ impl.PQfreemem(out)
+ return rv
+
+
+# importing the ssl module sets up Python's libcrypto callbacks
+import ssl # noqa
+
+# disable libcrypto setup in libpq, so it won't stomp on the callbacks
+# that have already been set up
+impl.PQinitOpenSSL(1, 0)
+
+__build_version__ = version()
diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg/psycopg/py.typed
diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py
new file mode 100644
index 0000000..cb28b57
--- /dev/null
+++ b/psycopg/psycopg/rows.py
@@ -0,0 +1,256 @@
+"""
+psycopg row factories
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import functools
+from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
+from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
+from collections import namedtuple
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from ._compat import Protocol
+from ._encodings import _as_python_identifier
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from psycopg.pq.abc import PGresult
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+
+T = TypeVar("T", covariant=True)
+
+# Row factories
+
+Row = TypeVar("Row", covariant=True)
+
+
+class RowMaker(Protocol[Row]):
+ """
+ Callable protocol taking a sequence of value and returning an object.
+
+ The sequence of value is what is returned from a database query, already
+ adapted to the right Python types. The return value is the object that your
+ program would like to receive: by default (`tuple_row()`) it is a simple
+ tuple, but it may be any type of object.
+
+ Typically, `!RowMaker` functions are returned by `RowFactory`.
+ """
+
+ def __call__(self, __values: Sequence[Any]) -> Row:
+ ...
+
+
+class RowFactory(Protocol[Row]):
+ """
+ Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`.
+
+ A `!RowFactory` is typically called when a `!Cursor` receives a result.
+ This way it can inspect the cursor state (for instance the
+ `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create
+ a complete object.
+
+ For instance the `dict_row()` `!RowFactory` uses the names of the column to
+ define the dictionary key and returns a `!RowMaker` function which would
+ use the values to create a dictionary for each record.
+ """
+
+ def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class AsyncRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking an async cursor as argument.
+ """
+
+ def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class BaseRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking either type of cursor as argument.
+ """
+
+ def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
+ ...
+
+
+TupleRow: TypeAlias = Tuple[Any, ...]
+"""
+An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
+"""
+
+
+DictRow: TypeAlias = Dict[str, Any]
+"""
+An alias for the type returned by `dict_row()`
+
+A `!DictRow` is a dictionary with keys as string and any value returned by the
+database.
+"""
+
+
+def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
+ r"""Row factory to represent rows as simple tuples.
+
+ This is the default factory, used when `~psycopg.Connection.connect()` or
+ `~psycopg.Connection.cursor()` are called without a `!row_factory`
+ parameter.
+
+ """
+ # Implementation detail: make sure this is the tuple type itself, not an
+ # equivalent function, because the C code fast-paths on it.
+ return tuple
+
+
+def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
+ """Row factory to represent rows as dictionaries.
+
+ The dictionary keys are taken from the column names of the returned columns.
+ """
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
+ # https://github.com/python/mypy/issues/2608
+ return dict(zip(names, values)) # type: ignore[arg-type]
+
+ return dict_row_
+
+
+def namedtuple_row(
+ cursor: "BaseCursor[Any, Any]",
+) -> "RowMaker[NamedTuple]":
+ """Row factory to represent rows as `~collections.namedtuple`.
+
+ The field names are taken from the column names of the returned columns,
+ with some mangling to deal with invalid names.
+ """
+ res = cursor.pgresult
+ if not res:
+ return no_result
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return no_result
+
+ nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
+ return nt._make
+
+
+@functools.lru_cache(512)
+def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
+ snames = tuple(_as_python_identifier(n.decode(enc)) for n in names)
+ return namedtuple("Row", snames) # type: ignore[return-value]
+
+
+def class_row(cls: Type[T]) -> BaseRowFactory[T]:
+ r"""Generate a row factory to represent rows as instances of the class `!cls`.
+
+ The class must support every output column name as a keyword parameter.
+
+ :param cls: The class to return for each row. It must support the fields
+ returned by the query as keyword arguments.
+ :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
+ """
+
+ def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def class_row__(values: Sequence[Any]) -> T:
+ return cls(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return class_row__
+
+ return class_row_
+
+
+def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with positional parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as positional arguments.
+ """
+
+ def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ def args_row__(values: Sequence[Any]) -> T:
+ return func(*values)
+
+ return args_row__
+
+ return args_row_
+
+
+def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with keyword parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as keyword arguments.
+ """
+
+ def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def kwargs_row__(values: Sequence[Any]) -> T:
+ return func(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return kwargs_row__
+
+ return kwargs_row_
+
+
+def no_result(values: Sequence[Any]) -> NoReturn:
+ """A `RowMaker` that always fail.
+
+ It can be used as return value for a `RowFactory` called with no result.
+ Note that the `!RowFactory` *will* be called with no result, but the
+ resulting `!RowMaker` never should.
+ """
+ raise e.InterfaceError("the cursor doesn't have a result")
+
+
+def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
+ res = cursor.pgresult
+ if not res:
+ return None
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return None
+
+ enc = cursor._encoding
+ return [
+ res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
+ ]
+
+
+def _get_nfields(res: "PGresult") -> Optional[int]:
+ """
+ Return the number of columns in a result, if it returns tuples else None
+
+ Take into account the special case of results with zero columns.
+ """
+ nfields = res.nfields
+
+ if (
+ res.status == TUPLES_OK
+ or res.status == SINGLE_TUPLE
+ # "describe" in named cursors
+ or (res.status == COMMAND_OK and nfields)
+ ):
+ return nfields
+ else:
+ return None
diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py
new file mode 100644
index 0000000..b890d77
--- /dev/null
+++ b/psycopg/psycopg/server_cursor.py
@@ -0,0 +1,479 @@
+"""
+psycopg server-side cursor objects.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, AsyncIterator, List, Iterable, Iterator
+from typing import Optional, TypeVar, TYPE_CHECKING, overload
+from warnings import warn
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .rows import Row, RowFactory, AsyncRowFactory
+from .cursor import BaseCursor, Cursor
+from .generators import execute
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+DEFAULT_ITERSIZE = 100
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+
+class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
+ """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
+
+ __slots__ = "_name _scrollable _withhold _described itersize _format".split()
+
+ def __init__(
+ self,
+ name: str,
+ scrollable: Optional[bool],
+ withhold: bool,
+ ):
+ self._name = name
+ self._scrollable = scrollable
+ self._withhold = withhold
+ self._described = False
+ self.itersize: int = DEFAULT_ITERSIZE
+ self._format = TEXT
+
+ def __repr__(self) -> str:
+ # Insert the name as the second word
+ parts = super().__repr__().split(None, 1)
+ parts.insert(1, f"{self._name!r}")
+ return " ".join(parts)
+
+ @property
+ def name(self) -> str:
+ """The name of the cursor."""
+ return self._name
+
+ @property
+ def scrollable(self) -> Optional[bool]:
+ """
+ Whether the cursor is scrollable or not.
+
+ If `!None` leave the choice to the server. Use `!True` if you want to
+ use `scroll()` on the cursor.
+ """
+ return self._scrollable
+
+ @property
+ def withhold(self) -> bool:
+ """
+ If the cursor can be used after the creating transaction has committed.
+ """
+ return self._withhold
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ res = self.pgresult
+ # command_status is empty if the result comes from
+ # describe_portal, which means that we have just executed the DECLARE,
+ # so we can assume we are at the first row.
+ tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
+ return self._pos if tuples else None
+
+ def _declare_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `ServerCursor.execute()`."""
+
+ query = self._make_declare_statement(query)
+
+ # If the cursor is being reused, the previous one must be closed.
+ if self._described:
+ yield from self._close_gen()
+ self._described = False
+
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, force_extended=True)
+ results = yield from execute(self._conn.pgconn)
+ if results[-1].status != COMMAND_OK:
+ self._raise_for_result(results[-1])
+
+ # Set the format, which will be used by describe and fetch operations
+ if binary is None:
+ self._format = self.format
+ else:
+ self._format = BINARY if binary else TEXT
+
+ # The above result only returned COMMAND_OK. Get the cursor shape
+ yield from self._describe_gen()
+
+ def _describe_gen(self) -> PQGen[None]:
+ self._pgconn.send_describe_portal(self._name.encode(self._encoding))
+ results = yield from execute(self._pgconn)
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0, format=self._format)
+ self._described = True
+
+ def _close_gen(self) -> PQGen[None]:
+ ts = self._conn.pgconn.transaction_status
+
+ # if the connection is not in a sane state, don't even try
+ if ts != IDLE and ts != INTRANS:
+ return
+
+ # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
+ if not self._withhold and ts == IDLE:
+ return
+
+ # if we didn't declare the cursor ourselves we still have to close it
+ # but we must make sure it exists.
+ if not self._described:
+ query = sql.SQL(
+ "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
+ ).format(sql.Literal(self._name))
+ res = yield from self._conn._exec_command(query)
+ # pipeline mode otherwise, unsupported here.
+ assert res is not None
+ if res.ntuples == 0:
+ return
+
+ query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
+ yield from self._conn._exec_command(query)
+
+ def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ # If we are stealing the cursor, make sure we know its shape
+ if not self._described:
+ yield from self._start_query()
+ yield from self._describe_gen()
+
+ query = sql.SQL("FETCH FORWARD {} FROM {}").format(
+ sql.SQL("ALL") if num is None else sql.Literal(num),
+ sql.Identifier(self._name),
+ )
+ res = yield from self._conn._exec_command(query, result_format=self._format)
+ # pipeline mode otherwise, unsupported here.
+ assert res is not None
+
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=False)
+ return self._tx.load_rows(0, res.ntuples, self._make_row)
+
+ def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+ if mode not in ("relative", "absolute"):
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ query = sql.SQL("MOVE{} {} FROM {}").format(
+ sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
+ sql.Literal(value),
+ sql.Identifier(self._name),
+ )
+ yield from self._conn._exec_command(query)
+
+ def _make_declare_statement(self, query: Query) -> sql.Composed:
+
+ if isinstance(query, bytes):
+ query = query.decode(self._encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
+
+ parts = [
+ sql.SQL("DECLARE"),
+ sql.Identifier(self._name),
+ ]
+ if self._scrollable is not None:
+ parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
+ parts.append(sql.SQL("CURSOR"))
+ if self._withhold:
+ parts.append(sql.SQL("WITH HOLD"))
+ parts.append(sql.SQL("FOR"))
+ parts.append(query)
+
+ return sql.SQL(" ").join(parts)
+
+
+class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="ServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: RowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ Cursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ self._conn.wait(self._close_gen())
+ super().close()
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ """
+ Open a cursor to execute a query to the database.
+ """
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ """Method not implemented for server-side cursors."""
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ def fetchone(self) -> Optional[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ def fetchall(self) -> List[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ def __iter__(self) -> Iterator[Row]:
+ while True:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ with self._conn.lock:
+ self._conn.wait(self._scroll_gen(value, mode))
+ # Postgres doesn't have a reliable way to report a cursor out of bound
+ if mode == "relative":
+ self._pos += value
+ else:
+ self._pos = value
+
+
+class AsyncServerCursor(
+ ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: AsyncRowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ AsyncCursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ async def close(self) -> None:
+ async with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ await self._conn.wait(self._close_gen())
+ await super().close()
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ async def fetchone(self) -> Optional[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ async def fetchall(self) -> List[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ while True:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ async with self._conn.lock:
+ await self._conn.wait(self._scroll_gen(value, mode))
diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py
new file mode 100644
index 0000000..099a01c
--- /dev/null
+++ b/psycopg/psycopg/sql.py
@@ -0,0 +1,467 @@
+"""
+SQL composition utility module
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+import string
+from abc import ABC, abstractmethod
+from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
+
+from .pq import Escaping
+from .abc import AdaptContext
+from .adapt import Transformer, PyFormat
+from ._compat import LiteralString
+from ._encodings import conn_encoding
+
+
+def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
+ """
+ Adapt a Python object to a quoted SQL string.
+
+ Use this function only if you absolutely want to convert a Python string to
+ an SQL quoted literal to use e.g. to generate batch SQL and you won't have
+ a connection available when you will need to use it.
+
+ This function is relatively inefficient, because it doesn't cache the
+ adaptation rules. If you pass a `!context` you can adapt the adaptation
+ rules used, otherwise only global rules are used.
+
+ """
+ return Literal(obj).as_string(context)
+
+
+class Composable(ABC):
+ """
+ Abstract base class for objects that can be used to compose an SQL string.
+
+ `!Composable` objects can be passed directly to
+ `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
+ `~psycopg.Cursor.copy()` in place of the query string.
+
+ `!Composable` objects can be joined using the ``+`` operator: the result
+ will be a `Composed` instance containing the objects joined. The operator
+ ``*`` is also supported with an integer argument: the result is a
+ `!Composed` instance containing the left argument repeated as many times as
+ requested.
+ """
+
+ def __init__(self, obj: Any):
+ self._obj = obj
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._obj!r})"
+
+ @abstractmethod
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ """
+ Return the value of the object as bytes.
+
+ :param context: the context to evaluate the object into.
+ :type context: `connection` or `cursor`
+
+ The method is automatically invoked by `~psycopg.Cursor.execute()`,
+ `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
+ `!Composable` is passed instead of the query string.
+
+ """
+ raise NotImplementedError
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ """
+ Return the value of the object as string.
+
+ :param context: the context to evaluate the string into.
+ :type context: `connection` or `cursor`
+
+ """
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ b = self.as_bytes(context)
+ if isinstance(b, bytes):
+ return b.decode(enc)
+ else:
+ # buffer object
+ return codecs.lookup(enc).decode(b)[0]
+
+ def __add__(self, other: "Composable") -> "Composed":
+ if isinstance(other, Composed):
+ return Composed([self]) + other
+ if isinstance(other, Composable):
+ return Composed([self]) + Composed([other])
+ else:
+ return NotImplemented
+
+ def __mul__(self, n: int) -> "Composed":
+ return Composed([self] * n)
+
+ def __eq__(self, other: Any) -> bool:
+ return type(self) is type(other) and self._obj == other._obj
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Composed(Composable):
+ """
+ A `Composable` object made of a sequence of `!Composable`.
+
+ The object is usually created using `!Composable` operators and methods.
+ However it is possible to create a `!Composed` directly specifying a
+ sequence of objects as arguments: if they are not `!Composable` they will
+ be wrapped in a `Literal`.
+
+ Example::
+
+ >>> comp = sql.Composed(
+ ... [sql.SQL("INSERT INTO "), sql.Identifier("table")])
+ >>> print(comp.as_string(conn))
+ INSERT INTO "table"
+
+ `!Composed` objects are iterable (so they can be used in `SQL.join` for
+ instance).
+ """
+
+ _obj: List[Composable]
+
+ def __init__(self, seq: Sequence[Any]):
+ seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
+ super().__init__(seq)
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ return b"".join(obj.as_bytes(context) for obj in self._obj)
+
+ def __iter__(self) -> Iterator[Composable]:
+ return iter(self._obj)
+
+ def __add__(self, other: Composable) -> "Composed":
+ if isinstance(other, Composed):
+ return Composed(self._obj + other._obj)
+ if isinstance(other, Composable):
+ return Composed(self._obj + [other])
+ else:
+ return NotImplemented
+
+ def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
+ """
+ Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
+
+ The `!joiner` must be a `SQL` or a string which will be interpreted as
+ an `SQL`.
+
+ Example::
+
+ >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
+ >>> print(fields.join(', ').as_string(conn))
+ "foo", "bar"
+
+ """
+ if isinstance(joiner, str):
+ joiner = SQL(joiner)
+ elif not isinstance(joiner, SQL):
+ raise TypeError(
+ "Composed.join() argument must be strings or SQL,"
+ f" got {joiner!r} instead"
+ )
+
+ return joiner.join(self._obj)
+
+
+class SQL(Composable):
+ """
+ A `Composable` representing a snippet of SQL statement.
+
+ `!SQL` exposes `join()` and `format()` methods useful to create a template
+ where to merge variable parts of a query (for instance field or table
+ names).
+
+ The `!obj` string doesn't undergo any form of escaping, so it is not
+ suitable to represent variable identifiers or values: you should only use
+ it to pass constant strings representing templates or snippets of SQL
+ statements; use other objects such as `Identifier` or `Literal` to
+ represent variable parts.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {0} FROM {1}").format(
+ ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
+ ... sql.Identifier('table'))
+ >>> print(query.as_string(conn))
+ SELECT "foo", "bar" FROM "table"
+ """
+
+ _obj: LiteralString
+ _formatter = string.Formatter()
+
+ def __init__(self, obj: LiteralString):
+ super().__init__(obj)
+ if not isinstance(obj, str):
+ raise TypeError(f"SQL values must be strings, got {obj!r} instead")
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ return self._obj
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ enc = "utf-8"
+ if context:
+ enc = conn_encoding(context.connection)
+ return self._obj.encode(enc)
+
+ def format(self, *args: Any, **kwargs: Any) -> Composed:
+ """
+ Merge `Composable` objects into a template.
+
+ :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
+ auto-numbered (``{}``) placeholders
+ :param kwargs: parameters to replace to named (``{name}``) placeholders
+ :return: the union of the `!SQL` string with placeholders replaced
+ :rtype: `Composed`
+
+ The method is similar to the Python `str.format()` method: the string
+ template supports auto-numbered (``{}``), numbered (``{0}``,
+ ``{1}``...), and named placeholders (``{name}``), with positional
+ arguments replacing the numbered placeholders and keywords replacing
+ the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
+ are not supported.
+
+ If a `!Composable` objects is passed to the template it will be merged
+ according to its `as_string()` method. If any other Python object is
+ passed, it will be wrapped in a `Literal` object and so escaped
+ according to SQL rules.
+
+ Example::
+
+ >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
+ ... .format(sql.Identifier('people'), sql.Identifier('id'))
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE "id" = %s
+
+ >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
+ ... .format(tbl=sql.Identifier('people'), name="O'Rourke"))
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE name = 'O''Rourke'
+
+ """
+ rv: List[Composable] = []
+ autonum: Optional[int] = 0
+ # TODO: this is probably not the right way to whitelist pre
+ # pyre complains. Will wait for mypy to complain too to fix.
+ pre: LiteralString
+ for pre, name, spec, conv in self._formatter.parse(self._obj):
+ if spec:
+ raise ValueError("no format specification supported by SQL")
+ if conv:
+ raise ValueError("no format conversion supported by SQL")
+ if pre:
+ rv.append(SQL(pre))
+
+ if name is None:
+ continue
+
+ if name.isdigit():
+ if autonum:
+ raise ValueError(
+ "cannot switch from automatic field numbering to manual"
+ )
+ rv.append(args[int(name)])
+ autonum = None
+
+ elif not name:
+ if autonum is None:
+ raise ValueError(
+ "cannot switch from manual field numbering to automatic"
+ )
+ rv.append(args[autonum])
+ autonum += 1
+
+ else:
+ rv.append(kwargs[name])
+
+ return Composed(rv)
+
+ def join(self, seq: Iterable[Composable]) -> Composed:
+ """
+ Join a sequence of `Composable`.
+
+ :param seq: the elements to join.
+ :type seq: iterable of `!Composable`
+
+ Use the `!SQL` object's string to separate the elements in `!seq`.
+ Note that `Composed` objects are iterable too, so they can be used as
+ argument for this method.
+
+ Example::
+
+ >>> snip = sql.SQL(', ').join(
+ ... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
+ >>> print(snip.as_string(conn))
+ "foo", "bar", "baz"
+ """
+ rv = []
+ it = iter(seq)
+ try:
+ rv.append(next(it))
+ except StopIteration:
+ pass
+ else:
+ for i in it:
+ rv.append(self)
+ rv.append(i)
+
+ return Composed(rv)
+
+
+class Identifier(Composable):
+ """
+ A `Composable` representing an SQL identifier or a dot-separated sequence.
+
+ Identifiers usually represent names of database objects, such as tables or
+ fields. PostgreSQL identifiers follow `different rules`__ than SQL string
+ literals for escaping (e.g. they use double quotes instead of single).
+
+ .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
+ SQL-SYNTAX-IDENTIFIERS
+
+ Example::
+
+ >>> t1 = sql.Identifier("foo")
+ >>> t2 = sql.Identifier("ba'r")
+ >>> t3 = sql.Identifier('ba"z')
+ >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
+ "foo", "ba'r", "ba""z"
+
+ Multiple strings can be passed to the object to represent a qualified name,
+ i.e. a dot-separated sequence of identifiers.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {} FROM {}").format(
+ ... sql.Identifier("table", "field"),
+ ... sql.Identifier("schema", "table"))
+ >>> print(query.as_string(conn))
+ SELECT "table"."field" FROM "schema"."table"
+
+ """
+
+ _obj: Sequence[str]
+
+ def __init__(self, *strings: str):
+ # init super() now to make the __repr__ not explode in case of error
+ super().__init__(strings)
+
+ if not strings:
+ raise TypeError("Identifier cannot be empty")
+
+ for s in strings:
+ if not isinstance(s, str):
+ raise TypeError(
+ f"SQL identifier parts must be strings, got {s!r} instead"
+ )
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ if not conn:
+ raise ValueError("a connection is necessary for Identifier")
+ esc = Escaping(conn.pgconn)
+ enc = conn_encoding(conn)
+ escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ return b".".join(escs)
+
+
+class Literal(Composable):
+ """
+ A `Composable` representing an SQL value to include in a query.
+
+ Usually you will want to include placeholders in the query and pass values
+ as `~cursor.execute()` arguments. If however you really really need to
+ include a literal value in the query you can use this object.
+
+ The string returned by `!as_string()` follows the normal :ref:`adaptation
+ rules <types-adaptation>` for Python objects.
+
+ Example::
+
+ >>> s1 = sql.Literal("fo'o")
+ >>> s2 = sql.Literal(42)
+ >>> s3 = sql.Literal(date(2000, 1, 1))
+ >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
+ 'fo''o', 42, '2000-01-01'::date
+
+ """
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ tx = Transformer.from_context(context)
+ return tx.as_literal(self._obj)
+
+
+class Placeholder(Composable):
+ """A `Composable` representing a placeholder for query parameters.
+
+ If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
+ ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
+ ``%b``).
+
+ The object is useful to generate SQL queries with a variable number of
+ arguments.
+
+ Examples::
+
+ >>> names = ['foo', 'bar', 'baz']
+
+ >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(sql.Placeholder() * len(names)))
+ >>> print(q1.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)
+
+ >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(map(sql.Placeholder, names)))
+ >>> print(q2.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
+
+ """
+
+ def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
+ super().__init__(name)
+ if not isinstance(name, str):
+ raise TypeError(f"expected string as name, got {name!r}")
+
+ if ")" in name:
+ raise ValueError(f"invalid name: {name!r}")
+
+ if type(format) is str:
+ format = PyFormat(format)
+ if not isinstance(format, PyFormat):
+ raise TypeError(
+ f"expected PyFormat as format, got {type(format).__name__!r}"
+ )
+
+ self._format: PyFormat = format
+
+ def __repr__(self) -> str:
+ parts = []
+ if self._obj:
+ parts.append(repr(self._obj))
+ if self._format is not PyFormat.AUTO:
+ parts.append(f"format={self._format.name}")
+
+ return f"{self.__class__.__name__}({', '.join(parts)})"
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ code = self._format.value
+ return f"%({self._obj}){code}" if self._obj else f"%{code}"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ return self.as_string(context).encode(enc)
+
+
+# Literals
+NULL = SQL("NULL")
+DEFAULT = SQL("DEFAULT")
diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py
new file mode 100644
index 0000000..e13486e
--- /dev/null
+++ b/psycopg/psycopg/transaction.py
@@ -0,0 +1,290 @@
+"""
+Transaction context managers returned by Connection.transaction()
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from types import TracebackType
+from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, PQGen
+
+if TYPE_CHECKING:
+ from typing import Any
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+IDLE = pq.TransactionStatus.IDLE
+
+OK = pq.ConnStatus.OK
+
+logger = logging.getLogger(__name__)
+
+
+class Rollback(Exception):
+ """
+ Exit the current `Transaction` context immediately and rollback any changes
+ made within this context.
+
+ If a transaction context is specified in the constructor, rollback
+ enclosing transactions contexts up to and including the one specified.
+ """
+
+ __module__ = "psycopg"
+
+ def __init__(
+ self,
+ transaction: Union["Transaction", "AsyncTransaction", None] = None,
+ ):
+ self.transaction = transaction
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__qualname__}({self.transaction!r})"
+
+
+class OutOfOrderTransactionNesting(e.ProgrammingError):
+ """Out-of-order transaction nesting detected"""
+
+
+class BaseTransaction(Generic[ConnectionType]):
+ def __init__(
+ self,
+ connection: ConnectionType,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ):
+ self._conn = connection
+ self.pgconn = self._conn.pgconn
+ self._savepoint_name = savepoint_name or ""
+ self.force_rollback = force_rollback
+ self._entered = self._exited = False
+ self._outer_transaction = False
+ self._stack_index = -1
+
+ @property
+ def savepoint_name(self) -> Optional[str]:
+ """
+ The name of the savepoint; `!None` if handling the main transaction.
+ """
+ # Yes, it may change on __enter__. No, I don't care, because the
+ # un-entered state is outside the public interface.
+ return self._savepoint_name
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ if not self._entered:
+ status = "inactive"
+ elif not self._exited:
+ status = "active"
+ else:
+ status = "terminated"
+
+ sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
+ return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
+
+ def _enter_gen(self) -> PQGen[None]:
+ if self._entered:
+ raise TypeError("transaction blocks can be used only once")
+ self._entered = True
+
+ self._push_savepoint()
+ for command in self._get_enter_commands():
+ yield from self._conn._exec_command(command)
+
+ def _exit_gen(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> PQGen[bool]:
+ if not exc_val and not self.force_rollback:
+ yield from self._commit_gen()
+ return False
+ else:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ return (yield from self._rollback_gen(exc_val))
+ except OutOfOrderTransactionNesting:
+ # Clobber an exception happened in the block with the exception
+ # caused by out-of-order transaction detected, so make the
+ # behaviour consistent with _commit_gen and to make sure the
+ # user fixes this condition, which is unrelated from
+ # operational error that might arise in the block.
+ raise
+ except Exception as exc2:
+ logger.warning("error ignored in rollback of %s: %s", self, exc2)
+ return False
+
+ def _commit_gen(self) -> PQGen[None]:
+ ex = self._pop_savepoint("commit")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_commit_commands():
+ yield from self._conn._exec_command(command)
+
+ def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
+ if isinstance(exc_val, Rollback):
+ logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
+
+ ex = self._pop_savepoint("rollback")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_rollback_commands():
+ yield from self._conn._exec_command(command)
+
+ if isinstance(exc_val, Rollback):
+ if not exc_val.transaction or exc_val.transaction is self:
+ return True # Swallow the exception
+
+ return False
+
+ def _get_enter_commands(self) -> Iterator[bytes]:
+ if self._outer_transaction:
+ yield self._conn._get_tx_start_command()
+
+ if self._savepoint_name:
+ yield (
+ sql.SQL("SAVEPOINT {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ def _get_commit_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("RELEASE {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"COMMIT"
+
+ def _get_rollback_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("ROLLBACK TO {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+ yield (
+ sql.SQL("RELEASE {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"ROLLBACK"
+
+ # Also clear the prepared statements cache.
+ if self._conn._prepared.clear():
+ yield from self._conn._prepared.get_maintenance_commands()
+
+ def _push_savepoint(self) -> None:
+ """
+ Push the transaction on the connection transactions stack.
+
+ Also set the internal state of the object and verify consistency.
+ """
+ self._outer_transaction = self.pgconn.transaction_status == IDLE
+ if self._outer_transaction:
+ # outer transaction: if no name it's only a begin, else
+ # there will be an additional savepoint
+ assert not self._conn._num_transactions
+ else:
+ # inner transaction: it always has a name
+ if not self._savepoint_name:
+ self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
+
+ self._stack_index = self._conn._num_transactions
+ self._conn._num_transactions += 1
+
+ def _pop_savepoint(self, action: str) -> Optional[Exception]:
+ """
+ Pop the transaction from the connection transactions stack.
+
+ Also verify the state consistency.
+ """
+ self._conn._num_transactions -= 1
+ if self._conn._num_transactions == self._stack_index:
+ return None
+
+ return OutOfOrderTransactionNesting(
+ f"transaction {action} at the wrong nesting level: {self}"
+ )
+
+
+class Transaction(BaseTransaction["Connection[Any]"]):
+ """
+ Returned by `Connection.transaction()` to handle a transaction block.
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="Transaction")
+
+ @property
+ def connection(self) -> "Connection[Any]":
+ """The connection the object is managing."""
+ return self._conn
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ with self._conn.lock:
+ return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
+
+
+class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
+ """
+ Returned by `AsyncConnection.transaction()` to handle a transaction block.
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="AsyncTransaction")
+
+ @property
+ def connection(self) -> "AsyncConnection[Any]":
+ return self._conn
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ async with self._conn.lock:
+ return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py
new file mode 100644
index 0000000..bdddf05
--- /dev/null
+++ b/psycopg/psycopg/types/__init__.py
@@ -0,0 +1,11 @@
+"""
+psycopg types package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import _typeinfo
+
+# Exposed here
+TypeInfo = _typeinfo.TypeInfo
+TypesRegistry = _typeinfo.TypesRegistry
diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py
new file mode 100644
index 0000000..e35c5e7
--- /dev/null
+++ b/psycopg/psycopg/types/array.py
@@ -0,0 +1,464 @@
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
+
+from .. import pq
+from .. import errors as e
+from .. import postgres
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._compat import cache, prod
+from .._struct import pack_len, unpack_len
+from .._cmodule import _psycopg
+from ..postgres import TEXT_OID, INVALID_OID
+from .._typeinfo import TypeInfo
+
+_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
+_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
+_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
+_struct_dim = struct.Struct("!II") # dim, lower bound
+_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
+_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
+
+TEXT_ARRAY_OID = postgres.types["text"].array_oid
+
+PY_TEXT = PyFormat.TEXT
+PQ_BINARY = pq.Format.BINARY
+
+
+class BaseListDumper(RecursiveDumper):
+ element_oid = 0
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ if cls is NoneType:
+ cls = list
+
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ if self.element_oid and context:
+ sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
+ self.sub_dumper = sdclass(NoneType, context)
+
+ def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
+ """
+ Find the first non-null element of an eventually nested list
+ """
+ items = list(self._flatiter(L, set()))
+ types = {type(item): item for item in items}
+ if not types:
+ return None
+
+ if len(types) == 1:
+ t, v = types.popitem()
+ else:
+ # More than one type in the list. It might be still good, as long
+ # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
+ dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
+ oids = set(d.oid for d in dumpers)
+ if len(oids) == 1:
+ t, v = types.popitem()
+ else:
+ raise e.DataError(
+ "cannot dump lists of mixed types;"
+ f" got: {', '.join(sorted(t.__name__ for t in types))}"
+ )
+
+ # Checking for precise type. If the type is a subclass (e.g. Int4)
+ # we assume the user knows what type they are passing.
+ if t is not int:
+ return v
+
+ # If we got an int, let's see what is the biggest one in order to
+ # choose the smallest OID and allow Postgres to do the right cast.
+ imax: int = max(items)
+ imin: int = min(items)
+ if imin >= 0:
+ return imax
+ else:
+ return max(imax, -imin - 1)
+
+ def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
+ if id(L) in seen:
+ raise e.DataError("cannot dump a recursive list")
+
+ seen.add(id(L))
+
+ for item in L:
+ if type(item) is list:
+ yield from self._flatiter(item, seen)
+ elif item is not None:
+ yield item
+
+ return None
+
+ def _get_base_type_info(self, base_oid: int) -> TypeInfo:
+ """
+ Return info about the base type.
+
+ Return text info as fallback.
+ """
+ if base_oid:
+ info = self._tx.adapters.types.get(base_oid)
+ if info:
+ return info
+
+ return self._tx.adapters.types["text"]
+
+
+class ListDumper(BaseListDumper):
+
+ delimiter = b","
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return self.cls
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ # Empty lists can only be dumped as text if the type is unknown.
+ return self
+
+ sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+
+ # We consider an array of unknowns as unknown, so we can dump empty
+ # lists or lists containing only None elements.
+ if sd.oid != INVALID_OID:
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+ dumper.delimiter = info.delimiter.encode()
+ else:
+ dumper.oid = INVALID_OID
+
+ return dumper
+
+ # Double quotes and backslashes embedded in element values will be
+ # backslash-escaped.
+ _re_esc = re.compile(rb'(["\\])')
+
+ def dump(self, obj: List[Any]) -> bytes:
+ tokens: List[Buffer] = []
+ needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
+
+ def dump_list(obj: List[Any]) -> None:
+ if not obj:
+ tokens.append(b"{}")
+ return
+
+ tokens.append(b"{")
+ for item in obj:
+ if isinstance(item, list):
+ dump_list(item)
+ elif item is not None:
+ ad = self._dump_item(item)
+ if needs_quotes(ad):
+ if not isinstance(ad, bytes):
+ ad = bytes(ad)
+ ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
+ tokens.append(ad)
+ else:
+ tokens.append(b"NULL")
+
+ tokens.append(self.delimiter)
+
+ tokens[-1] = b"}"
+
+ dump_list(obj)
+
+ return b"".join(tokens)
+
+ def _dump_item(self, item: Any) -> Buffer:
+ if self.sub_dumper:
+ return self.sub_dumper.dump(item)
+ else:
+ return self._tx.get_dumper(item, PY_TEXT).dump(item)
+
+
+@cache
+def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """Return a regexp to recognise when a value needs quotes
+
+ from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+
+ The array output routine will put double quotes around element values if
+ they are empty strings, contain curly braces, delimiter characters,
+ double quotes, backslashes, or white space, or match the word NULL.
+ """
+ return re.compile(
+ rb"""(?xi)
+ ^$ # the empty string
+ | ["{}%s\\\s] # or a char to escape
+ | ^null$ # or the word NULL
+ """
+ % delimiter
+ )
+
+
+class ListBinaryDumper(BaseListDumper):
+
+ format = pq.Format.BINARY
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return (self.cls,)
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return ListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+
+ return dumper
+
+ def dump(self, obj: List[Any]) -> bytes:
+ # Postgres won't take unknown for element oid: fall back on text
+ sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
+
+ if not obj:
+ return _pack_head(0, 0, sub_oid)
+
+ data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
+ dims: List[int] = []
+ hasnull = 0
+
+ def calc_dims(L: List[Any]) -> None:
+ if isinstance(L, self.cls):
+ if not L:
+ raise e.DataError("lists cannot contain empty lists")
+ dims.append(len(L))
+ calc_dims(L[0])
+
+ calc_dims(obj)
+
+ def dump_list(L: List[Any], dim: int) -> None:
+ nonlocal hasnull
+ if len(L) != dims[dim]:
+ raise e.DataError("nested lists have inconsistent lengths")
+
+ if dim == len(dims) - 1:
+ for item in L:
+ if item is not None:
+ # If we get here, the sub_dumper must have been set
+ ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
+ data.append(pack_len(len(ad)))
+ data.append(ad)
+ else:
+ hasnull = 1
+ data.append(b"\xff\xff\xff\xff")
+ else:
+ for item in L:
+ if not isinstance(item, self.cls):
+ raise e.DataError("nested lists have inconsistent depths")
+ dump_list(item, dim + 1) # type: ignore
+
+ dump_list(obj, 0)
+
+ data[0] = _pack_head(len(dims), hasnull, sub_oid)
+ data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
+ return b"".join(data)
+
+
+class ArrayLoader(RecursiveLoader):
+
+ delimiter = b","
+ base_oid: int
+
+ def load(self, data: Buffer) -> List[Any]:
+ loader = self._tx.get_loader(self.base_oid, self.format)
+ return _load_text(data, loader, self.delimiter)
+
+
+class ArrayBinaryLoader(RecursiveLoader):
+
+ format = pq.Format.BINARY
+
+ def load(self, data: Buffer) -> List[Any]:
+ return _load_binary(data, self._tx)
+
+
+def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ if not info.array_oid:
+ raise ValueError(f"the type info {info} doesn't describe an array")
+
+ base: Type[Any]
+ adapters = context.adapters if context else postgres.adapters
+
+ base = getattr(_psycopg, "ArrayLoader", ArrayLoader)
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "base_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ loader = type(name, (base,), attribs)
+ adapters.register_loader(info.array_oid, loader)
+
+ loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
+ adapters.register_loader(info.array_oid, loader)
+
+ base = ListDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+ base = ListBinaryDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ # The text dumper is more flexible as it can handle lists of mixed type,
+ # so register it later.
+ context.adapters.register_dumper(list, ListBinaryDumper)
+ context.adapters.register_dumper(list, ListDumper)
+
+
+def register_all_arrays(context: AdaptContext) -> None:
+ """
+ Associate the array oid of all the types in Loader.globals.
+
+ This function is designed to be called once at import time, after having
+ registered all the base loaders.
+ """
+ for t in context.adapters.types:
+ if t.array_oid:
+ t.register(context)
+
+
+def _load_text(
+ data: Buffer,
+ loader: Loader,
+ delimiter: bytes = b",",
+ __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
+) -> List[Any]:
+ rv = None
+ stack: List[Any] = []
+ a: List[Any] = []
+ rv = a
+ load = loader.load
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if data and data[0] == b"["[0]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ idx = data.find(b"=")
+ if idx == -1:
+ raise e.DataError("malformed array: no '=' after dimension information")
+ data = data[idx + 1 :]
+
+ re_parse = _get_array_parse_regexp(delimiter)
+ for m in re_parse.finditer(data):
+ t = m.group(1)
+ if t == b"{":
+ if stack:
+ stack[-1].append(a)
+ stack.append(a)
+ a = []
+
+ elif t == b"}":
+ if not stack:
+ raise e.DataError("malformed array: unexpected '}'")
+ rv = stack.pop()
+
+ else:
+ if not stack:
+ wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
+ raise e.DataError(f"malformed array: unexpected '{wat}'")
+ if t == b"NULL":
+ v = None
+ else:
+ if t.startswith(b'"'):
+ t = __re_unescape.sub(rb"\1", t[1:-1])
+ v = load(t)
+
+ stack[-1].append(v)
+
+ assert rv is not None
+ return rv
+
+
+@cache
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """
+ Return a regexp to tokenize an array representation into item and brackets
+ """
+ return re.compile(
+ rb"""(?xi)
+ ( [{}] # open or closed bracket
+ | " (?: [^"\\] | \\. )* " # or a quoted string
+ | [^"{}%s\\]+ # or an unquoted non-empty string
+ ) ,?
+ """
+ % delimiter
+ )
+
+
+def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
+ ndims, hasnull, oid = _unpack_head(data)
+ load = tx.get_loader(oid, PQ_BINARY).load
+
+ if not ndims:
+ return []
+
+ p = 12 + 8 * ndims
+ dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
+ nelems = prod(dims)
+
+ out: List[Any] = [None] * nelems
+ for i in range(nelems):
+ size = unpack_len(data, p)[0]
+ p += 4
+ if size == -1:
+ continue
+ out[i] = load(data[p : p + size])
+ p += size
+
+ # fon ndims > 1 we have to aggregate the array into sub-arrays
+ for dim in dims[-1:0:-1]:
+ out = [out[i : i + dim] for i in range(0, len(out), dim)]
+
+ return out
diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py
new file mode 100644
index 0000000..db7e181
--- /dev/null
+++ b/psycopg/psycopg/types/bool.py
@@ -0,0 +1,51 @@
+"""
+Adapters for booleans.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+
+class BoolDumper(Dumper):
+
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"t" if obj else b"f"
+
+ def quote(self, obj: bool) -> bytes:
+ return b"true" if obj else b"false"
+
+
+class BoolBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"\x01" if obj else b"\x00"
+
+
+class BoolLoader(Loader):
+ def load(self, data: Buffer) -> bool:
+ return data == b"t"
+
+
+class BoolBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> bool:
+ return data != b"\x00"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(bool, BoolDumper)
+ adapters.register_dumper(bool, BoolBinaryDumper)
+ adapters.register_loader("bool", BoolLoader)
+ adapters.register_loader("bool", BoolBinaryLoader)
diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py
new file mode 100644
index 0000000..1c609c3
--- /dev/null
+++ b/psycopg/psycopg/types/composite.py
@@ -0,0 +1,290 @@
+"""
+Support for composite types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from collections import namedtuple
+from typing import Any, Callable, cast, Iterator, List, Optional
+from typing import Sequence, Tuple, Type
+
+from .. import pq
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
+from .._struct import pack_len, unpack_len
+from ..postgres import TEXT_OID
+from .._typeinfo import CompositeInfo as CompositeInfo # exported here
+from .._encodings import _as_python_identifier
+
+_struct_oidlen = struct.Struct("!Ii")
+_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
+_unpack_oidlen = cast(
+ Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
+)
+
+
+class SequenceDumper(RecursiveDumper):
+ def _dump_sequence(
+ self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+ ) -> bytes:
+ if not obj:
+ return start + end
+
+ parts: List[Buffer] = [start]
+
+ for item in obj:
+ if item is None:
+ parts.append(sep)
+ continue
+
+ dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ ad = dumper.dump(item)
+ if not ad:
+ ad = b'""'
+ elif self._re_needs_quotes.search(ad):
+ ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"'
+
+ parts.append(ad)
+ parts.append(sep)
+
+ parts[-1] = end
+
+ return b"".join(parts)
+
+ _re_needs_quotes = re.compile(rb'[",\\\s()]')
+ _re_esc = re.compile(rb"([\\\"])")
+
+
+class TupleDumper(SequenceDumper):
+
+ # Should be this, but it doesn't work
+ # oid = postgres_types["record"].oid
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytes:
+ return self._dump_sequence(obj, b"(", b")", b",")
+
+
+class TupleBinaryDumper(RecursiveDumper):
+
+ format = pq.Format.BINARY
+
+ # Subclasses must set an info
+ info: CompositeInfo
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ nfields = len(self.info.field_types)
+ self._tx.set_dumper_types(self.info.field_types, self.format)
+ self._formats = (PyFormat.from_pq(self.format),) * nfields
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytearray:
+ out = bytearray(pack_len(len(obj)))
+ adapted = self._tx.dump_sequence(obj, self._formats)
+ for i in range(len(obj)):
+ b = adapted[i]
+ oid = self.info.field_types[i]
+ if b is not None:
+ out += _pack_oidlen(oid, len(b))
+ out += b
+ else:
+ out += _pack_oidlen(oid, -1)
+
+ return out
+
+
+class BaseCompositeLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
+ """
+ Split a non-empty representation of a composite type into components.
+
+ Terminators shouldn't be used in `!data` (so that both record and range
+ representations can be parsed).
+ """
+ for m in self._re_tokenize.finditer(data):
+ if m.group(1):
+ yield None
+ elif m.group(2) is not None:
+ yield self._re_undouble.sub(rb"\1", m.group(2))
+ else:
+ yield m.group(3)
+
+ # If the final group ended in `,` there is a final NULL in the record
+ # that the regexp couldn't parse.
+ if m and m.group().endswith(b","):
+ yield None
+
+ _re_tokenize = re.compile(
+ rb"""(?x)
+ (,) # an empty token, representing NULL
+ | " ((?: [^"] | "")*) " ,? # or a quoted string
+ | ([^",)]+) ,? # or an unquoted string
+ """
+ )
+
+ _re_undouble = re.compile(rb'(["\\])\1')
+
+
+class RecordLoader(BaseCompositeLoader):
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if data == b"()":
+ return ()
+
+ cast = self._tx.get_loader(TEXT_OID, self.format).load
+ return tuple(
+ cast(token) if token is not None else None
+ for token in self._parse_record(data[1:-1])
+ )
+
+
+class RecordBinaryLoader(Loader):
+ format = pq.Format.BINARY
+ _types_set = False
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ return self._tx.load_sequence(
+ tuple(
+ data[offset : offset + length] if length != -1 else None
+ for _, offset, length in self._walk_record(data)
+ )
+ )
+
+ def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
+ """
+ Yield a sequence of (oid, offset, length) for the content of the record
+ """
+ nfields = unpack_len(data, 0)[0]
+ i = 4
+ for _ in range(nfields):
+ oid, length = _unpack_oidlen(data, i)
+ yield oid, i + 8, length
+ i += (8 + length) if length > 0 else 8
+
+ def _config_types(self, data: Buffer) -> None:
+ oids = [r[0] for r in self._walk_record(data)]
+ self._tx.set_loader_types(oids, self.format)
+
+
+class CompositeLoader(RecordLoader):
+
+ factory: Callable[..., Any]
+ fields_types: List[int]
+ _types_set = False
+
+ def load(self, data: Buffer) -> Any:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ if data == b"()":
+ return type(self).factory()
+
+ return type(self).factory(
+ *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
+ )
+
+ def _config_types(self, data: Buffer) -> None:
+ self._tx.set_loader_types(self.fields_types, self.format)
+
+
+class CompositeBinaryLoader(RecordBinaryLoader):
+
+ format = pq.Format.BINARY
+ factory: Callable[..., Any]
+
+ def load(self, data: Buffer) -> Any:
+ r = super().load(data)
+ return type(self).factory(*r)
+
+
+def register_composite(
+ info: CompositeInfo,
+ context: Optional[AdaptContext] = None,
+ factory: Optional[Callable[..., Any]] = None,
+) -> None:
+ """Register the adapters to load and dump a composite type.
+
+ :param info: The object with the information about the composite to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param factory: Callable to convert the sequence of attributes read from
+ the composite into a Python object.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested composite available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ if not factory:
+ factory = namedtuple( # type: ignore
+ _as_python_identifier(info.name),
+ [_as_python_identifier(n) for n in info.field_names],
+ )
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[BaseCompositeLoader] = type(
+ f"{info.name.title()}Loader",
+ (CompositeLoader,),
+ {
+ "factory": factory,
+ "fields_types": info.field_types,
+ },
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ loader = type(
+ f"{info.name.title()}BinaryLoader",
+ (CompositeBinaryLoader,),
+ {"factory": factory},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # If the factory is a type, create and register dumpers for it
+ if isinstance(factory, type):
+ dumper = type(
+ f"{info.name.title()}BinaryDumper",
+ (TupleBinaryDumper,),
+ {"oid": info.oid, "info": info},
+ )
+ adapters.register_dumper(factory, dumper)
+
+ # Default to the text dumper because it is more flexible
+ dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid})
+ adapters.register_dumper(factory, dumper)
+
+ info.python_type = factory
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(tuple, TupleDumper)
+ adapters.register_loader("record", RecordLoader)
+ adapters.register_loader("record", RecordBinaryLoader)
diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py
new file mode 100644
index 0000000..f0dfe83
--- /dev/null
+++ b/psycopg/psycopg/types/datetime.py
@@ -0,0 +1,754 @@
+"""
+Adapters for date/time types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from datetime import date, datetime, time, timedelta, timezone
+from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from .._tz import get_tzinfo
+from ..abc import AdaptContext, DumperKey
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from ..errors import InterfaceError, DataError
+from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
+
+if TYPE_CHECKING:
+ from ..connection import BaseConnection
+
+_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
+_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
+_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
+
+_struct_interval = struct.Struct("!qii") # microseconds, days, months
+_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
+_unpack_interval = cast(
+ Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
+)
+
+utc = timezone.utc
+_pg_date_epoch_days = date(2000, 1, 1).toordinal()
+_pg_datetime_epoch = datetime(2000, 1, 1)
+_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
+_py_date_min_days = date.min.toordinal()
+
+
+class DateDumper(Dumper):
+
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DateBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ days = obj.toordinal() - _pg_date_epoch_days
+ return pack_int4(days)
+
+
+class _BaseTimeDumper(Dumper):
+ def get_key(self, obj: time, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseTimeTextDumper(_BaseTimeDumper):
+ def dump(self, obj: time) -> bytes:
+ return str(obj).encode()
+
+
+class TimeDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["time"].oid
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+class TimeTzDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["timetz"].oid
+
+
+class TimeBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["time"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ return pack_int8(us)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzBinaryDumper(self.cls)
+
+
+class TimeTzBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timetz"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ off = obj.utcoffset()
+ assert off is not None
+ return _pack_timetz(us, -int(off.total_seconds()))
+
+
+class _BaseDatetimeDumper(Dumper):
+ def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
+ def dump(self, obj: datetime) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DatetimeDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamptz"].oid
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzDumper(self.cls)
+
+
+class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamp"].oid
+
+
+class DatetimeBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamptz"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetimetz_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzBinaryDumper(self.cls)
+
+
+class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamp"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetime_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+
+class TimedeltaDumper(Dumper):
+
+ oid = postgres.types["interval"].oid
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ if self.connection:
+ if (
+ self.connection.pgconn.parameter_status(b"IntervalStyle")
+ == b"sql_standard"
+ ):
+ setattr(self, "dump", self._dump_sql)
+
+ def dump(self, obj: timedelta) -> bytes:
+ # The comma is parsed ok by PostgreSQL but it's not documented
+ # and it seems brittle to rely on it. CRDB doesn't consume it well.
+ return str(obj).encode().replace(b",", b"")
+
+ def _dump_sql(self, obj: timedelta) -> bytes:
+ # sql_standard format needs explicit signs
+ # otherwise -1 day 1 sec will mean -1 sec
+ return b"%+d day %+d second %+d microsecond" % (
+ obj.days,
+ obj.seconds,
+ obj.microseconds,
+ )
+
+
+class TimedeltaBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["interval"].oid
+
+ def dump(self, obj: timedelta) -> bytes:
+ micros = 1_000_000 * obj.seconds + obj.microseconds
+ return _pack_interval(micros, obj.days, 0)
+
+
+class DateLoader(Loader):
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> date:
+ if self._order == self._ORDER_YMD:
+ ye = data[:4]
+ mo = data[5:7]
+ da = data[8:]
+ elif self._order == self._ORDER_DMY:
+ da = data[:2]
+ mo = data[3:5]
+ ye = data[6:]
+ else:
+ mo = data[:2]
+ da = data[3:5]
+ ye = data[6:]
+
+ try:
+ return date(int(ye), int(mo), int(da))
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ if s == "infinity" or (s and len(s.split()[0]) > 10):
+ raise DataError(f"date too large (after year 10K): {s!r}") from None
+ elif s == "-infinity" or "BC" in s:
+ raise DataError(f"date too small (before year 1): {s!r}") from None
+ else:
+ raise DataError(f"can't parse date {s!r}: {ex}") from None
+
+
+class DateBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> date:
+ days = unpack_int4(data)[0] + _pg_date_epoch_days
+ try:
+ return date.fromordinal(days)
+ except (ValueError, OverflowError):
+ if days < _py_date_min_days:
+ raise DataError("date too small (before year 1)") from None
+ else:
+ raise DataError("date too large (after year 10K)") from None
+
+
+class TimeLoader(Loader):
+
+ _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}")
+
+ ho, mi, se, fr = m.groups()
+
+ # Pad the fraction of second to get micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return time(int(ho), int(mi), int(se), us)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}: {e}") from None
+
+
+class TimeBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val = unpack_int8(data)[0]
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ try:
+ return time(h, m, s, us)
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimetzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}")
+
+ ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone
+ off = 60 * 60 * int(oh)
+ if om:
+ off += 60 * int(om)
+ if os:
+ off += int(os)
+ tz = timezone(timedelta(0, off if sgn == b"+" else -off))
+
+ try:
+ return time(int(ho), int(mi), int(se), us, tz)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}: {e}") from None
+
+
+class TimetzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val, off = _unpack_timetz(data)
+
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+
+ try:
+ return time(h, m, s, us, timezone(timedelta(seconds=-off)))
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimestampLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ $
+ """
+ )
+ _re_format_pg = re.compile(
+ rb"""(?ix)
+ ^
+ [a-z]+ [^a-z0-9] # DoW, separator
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ [^a-z0-9] (\d+) # Year
+ $
+ """
+ )
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+ _ORDER_PGDM = 3
+ _ORDER_PGMD = 4
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S"): # SQL
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ elif ds.startswith(b"P"): # Postgres
+ self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
+ self._re_format = self._re_format_pg
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ if self._order == self._ORDER_YMD:
+ ye, mo, da, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_DMY:
+ da, mo, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_MDY:
+ mo, da, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ else:
+ if self._order == self._ORDER_PGDM:
+ da, mo, ho, mi, se, fr, ye = m.groups()
+ else:
+ mo, da, ho, mi, se, fr, ye = m.groups()
+ try:
+ imo = _month_abbr[mo]
+ except KeyError:
+ s = mo.decode("utf8", "replace")
+ raise DataError(f"can't parse month: {s!r}") from None
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+
+class TimestampBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ return _pg_datetime_epoch + timedelta(microseconds=micros)
+ except OverflowError:
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class TimestamptzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ ds = _get_datestyle(self.connection)
+ if not ds.startswith(b"I"): # not ISO
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone offset
+ soff = 60 * 60 * int(oh)
+ if om:
+ soff += 60 * int(om)
+ if os:
+ soff += int(os)
+ tzoff = timedelta(0, soff if sgn == b"+" else -soff)
+
+ # The return value is a datetime with the timezone of the connection
+ # (in order to be consistent with the binary loader, which is the only
+ # thing it can return). So create a temporary datetime object, in utc,
+ # shift it by the offset parsed from the timestamp, and then move it to
+ # the connection timezone.
+ dt = None
+ ex: Exception
+ try:
+ dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
+ return (dt - tzoff).astimezone(self._timezone)
+ except OverflowError as e:
+ # If we have created the temporary 'dt' it means that we have a
+ # datetime close to max, the shift pushed it past max, overflowing.
+ # In this case return the datetime in a fixed offset timezone.
+ if dt is not None:
+ return dt.replace(tzinfo=timezone(tzoff))
+ else:
+ ex = e
+ except ValueError as e:
+ ex = e
+
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+ def _load_notimpl(self, data: Buffer) -> datetime:
+ s = bytes(data).decode("utf8", "replace")
+ ds = _get_datestyle(self.connection).decode("ascii")
+ raise NotImplementedError(
+ f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
+ )
+
+
+class TimestamptzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
+ return ts.astimezone(self._timezone)
+ except OverflowError:
+ # If we were asked about a timestamp which would overflow in UTC,
+ # but not in the desired timezone (e.g. datetime.max at Chicago
+ # timezone) we can still save the day by shifting the value by the
+ # timezone offset and then replacing the timezone.
+ if self._timezone:
+ utcoff = self._timezone.utcoffset(
+ datetime.min if micros < 0 else datetime.max
+ )
+ if utcoff:
+ usoff = 1_000_000 * int(utcoff.total_seconds())
+ try:
+ ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
+ except OverflowError:
+ pass # will raise downstream
+ else:
+ return ts.replace(tzinfo=self._timezone)
+
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class IntervalLoader(Loader):
+
+ _re_interval = re.compile(
+ rb"""
+ (?: ([-+]?\d+) \s+ years? \s* )? # Years
+ (?: ([-+]?\d+) \s+ mons? \s* )? # Months
+ (?: ([-+]?\d+) \s+ days? \s* )? # Days
+ (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time
+ )?
+ """,
+ re.VERBOSE,
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if self.connection:
+ ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
+ if ints != b"postgres":
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> timedelta:
+ m = self._re_interval.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}")
+
+ ye, mo, da, sgn, ho, mi, se = m.groups()
+ days = 0
+ seconds = 0.0
+
+ if ye:
+ days += 365 * int(ye)
+ if mo:
+ days += 30 * int(mo)
+ if da:
+ days += int(da)
+
+ if ho:
+ seconds = 3600 * int(ho) + 60 * int(mi) + float(se)
+ if sgn == b"-":
+ seconds = -seconds
+
+ try:
+ return timedelta(days=days, seconds=seconds)
+ except OverflowError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}: {e}") from None
+
+ def _load_notimpl(self, data: Buffer) -> timedelta:
+ s = bytes(data).decode("utf8", "replace")
+ ints = (
+ self.connection
+ and self.connection.pgconn.parameter_status(b"IntervalStyle")
+ or b"unknown"
+ ).decode("utf8", "replace")
+ raise NotImplementedError(
+ f"can't parse interval with IntervalStyle {ints}: {s!r}"
+ )
+
+
+class IntervalBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> timedelta:
+ micros, days, months = _unpack_interval(data)
+ if months > 0:
+ years, months = divmod(months, 12)
+ days = days + 30 * months + 365 * years
+ elif months < 0:
+ years, months = divmod(-months, 12)
+ days = days - 30 * months - 365 * years
+
+ try:
+ return timedelta(days=days, microseconds=micros)
+ except OverflowError as e:
+ raise DataError(f"can't parse interval: {e}") from None
+
+
+def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
+ if conn:
+ ds = conn.pgconn.parameter_status(b"DateStyle")
+ if ds:
+ return ds
+
+ return b"ISO, DMY"
+
+
+def _get_timestamp_load_error(
+ conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
+) -> Exception:
+ s = bytes(data).decode("utf8", "replace")
+
+ def is_overflow(s: str) -> bool:
+ if not s:
+ return False
+
+ ds = _get_datestyle(conn)
+ if not ds.startswith(b"P"): # Postgres
+ return len(s.split()[0]) > 10 # date is first token
+ else:
+ return len(s.split()[-1]) > 4 # year is last token
+
+ if s == "-infinity" or s.endswith("BC"):
+ return DataError("timestamp too small (before year 1): {s!r}")
+ elif s == "infinity" or is_overflow(s):
+ return DataError(f"timestamp too large (after year 10K): {s!r}")
+ else:
+ return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
+
+
+_month_abbr = {
+ n: i
+ for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
+}
+
+# Pad to get microseconds from a fraction of seconds
+_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1]
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("datetime.date", DateDumper)
+ adapters.register_dumper("datetime.date", DateBinaryDumper)
+
+ # first register dumpers for 'timetz' oid, then the proper ones on time type.
+ adapters.register_dumper("datetime.time", TimeTzDumper)
+ adapters.register_dumper("datetime.time", TimeTzBinaryDumper)
+ adapters.register_dumper("datetime.time", TimeDumper)
+ adapters.register_dumper("datetime.time", TimeBinaryDumper)
+
+ # first register dumpers for 'timestamp' oid, then the proper ones
+ # on the datetime type.
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper)
+
+ adapters.register_dumper("datetime.timedelta", TimedeltaDumper)
+ adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper)
+
+ adapters.register_loader("date", DateLoader)
+ adapters.register_loader("date", DateBinaryLoader)
+ adapters.register_loader("time", TimeLoader)
+ adapters.register_loader("time", TimeBinaryLoader)
+ adapters.register_loader("timetz", TimetzLoader)
+ adapters.register_loader("timetz", TimetzBinaryLoader)
+ adapters.register_loader("timestamp", TimestampLoader)
+ adapters.register_loader("timestamp", TimestampBinaryLoader)
+ adapters.register_loader("timestamptz", TimestamptzLoader)
+ adapters.register_loader("timestamptz", TimestamptzBinaryLoader)
+ adapters.register_loader("interval", IntervalLoader)
+ adapters.register_loader("interval", IntervalBinaryLoader)
diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py
new file mode 100644
index 0000000..d3c7387
--- /dev/null
+++ b/psycopg/psycopg/types/enum.py
@@ -0,0 +1,177 @@
+"""
+Adapters for the enum type.
+"""
+from enum import Enum
+from typing import Any, Dict, Generic, Optional, Mapping, Sequence
+from typing import Tuple, Type, TypeVar, Union, cast
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from .._encodings import conn_encoding
+from .._typeinfo import EnumInfo as EnumInfo # exported here
+
+E = TypeVar("E", bound=Enum)
+
+EnumDumpMap: TypeAlias = Dict[E, bytes]
+EnumLoadMap: TypeAlias = Dict[bytes, E]
+EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
+
+
+class _BaseEnumLoader(Loader, Generic[E]):
+ """
+ Loader for a specific Enum class
+ """
+
+ enum: Type[E]
+ _load_map: EnumLoadMap[E]
+
+ def load(self, data: Buffer) -> E:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ try:
+ return self._load_map[data]
+ except KeyError:
+ enc = conn_encoding(self.connection)
+ label = data.decode(enc, "replace")
+ raise e.DataError(
+ f"bad member for enum {self.enum.__qualname__}: {label!r}"
+ )
+
+
+class _BaseEnumDumper(Dumper, Generic[E]):
+ """
+ Dumper for a specific Enum class
+ """
+
+ enum: Type[E]
+ _dump_map: EnumDumpMap[E]
+
+ def dump(self, value: E) -> Buffer:
+ return self._dump_map[value]
+
+
+class EnumDumper(Dumper):
+ """
+ Dumper for a generic Enum class
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._encoding = conn_encoding(self.connection)
+
+ def dump(self, value: E) -> Buffer:
+ return value.name.encode(self._encoding)
+
+
+class EnumBinaryDumper(EnumDumper):
+ format = Format.BINARY
+
+
+def register_enum(
+ info: EnumInfo,
+ context: Optional[AdaptContext] = None,
+ enum: Optional[Type[E]] = None,
+ *,
+ mapping: EnumMapping[E] = None,
+) -> None:
+ """Register the adapters to load and dump a enum type.
+
+ :param info: The object with the information about the enum to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
+ a new enum will be generated and exposed as `EnumInfo.enum`.
+ :param mapping: Override the mapping between `!enum` members and `!info`
+ labels.
+ """
+
+ if not info:
+ raise TypeError("no info passed. Is the requested enum available?")
+
+ if enum is None:
+ enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__))
+
+ info.enum = enum
+ adapters = context.adapters if context else postgres.adapters
+ info.register(context)
+
+ load_map = _make_load_map(info, enum, mapping, context)
+ attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
+
+ name = f"{info.name.title()}Loader"
+ loader = type(name, (_BaseEnumLoader,), attribs)
+ adapters.register_loader(info.oid, loader)
+
+ name = f"{info.name.title()}BinaryLoader"
+ loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
+ adapters.register_loader(info.oid, loader)
+
+ dump_map = _make_dump_map(info, enum, mapping, context)
+ attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
+
+ name = f"{enum.__name__}Dumper"
+ dumper = type(name, (_BaseEnumDumper,), attribs)
+ adapters.register_dumper(info.enum, dumper)
+
+ name = f"{enum.__name__}BinaryDumper"
+ dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY})
+ adapters.register_dumper(info.enum, dumper)
+
+
+def _make_load_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumLoadMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumLoadMap[E] = {}
+ for label in info.labels:
+ try:
+ member = enum[label]
+ except KeyError:
+ # tolerate a missing enum, assuming it won't be used. If it is we
+ # will get a DataError on fetch.
+ pass
+ else:
+ rv[label.encode(enc)] = member
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[label.encode(enc)] = member
+
+ return rv
+
+
+def _make_dump_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumDumpMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumDumpMap[E] = {}
+ for member in enum:
+ rv[member] = member.name.encode(enc)
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[member] = label.encode(enc)
+
+ return rv
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, EnumBinaryDumper)
+ context.adapters.register_dumper(Enum, EnumDumper)
diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py
new file mode 100644
index 0000000..e1ab1d5
--- /dev/null
+++ b/psycopg/psycopg/types/hstore.py
@@ -0,0 +1,131 @@
+"""
+Dict to hstore adaptation
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+from typing import Dict, List, Optional
+from typing_extensions import TypeAlias
+
+from .. import errors as e
+from .. import postgres
+from ..abc import Buffer, AdaptContext
+from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..postgres import TEXT_OID
+from .._typeinfo import TypeInfo
+
+_re_escape = re.compile(r'(["\\])')
+_re_unescape = re.compile(r"\\(.)")
+
+_re_hstore = re.compile(
+ r"""
+ # hstore key:
+ # a string of normal or escaped chars
+ "((?: [^"\\] | \\. )*)"
+ \s*=>\s* # hstore value
+ (?:
+ NULL # the value can be null - not caught
+ # or a quoted string like the key
+ | "((?: [^"\\] | \\. )*)"
+ )
+ (?:\s*,\s*|$) # pairs separated by comma or end of string.
+""",
+ re.VERBOSE,
+)
+
+
+Hstore: TypeAlias = Dict[str, Optional[str]]
+
+
+class BaseHstoreDumper(RecursiveDumper):
+ def dump(self, obj: Hstore) -> Buffer:
+ if not obj:
+ return b""
+
+ tokens: List[str] = []
+
+ def add_token(s: str) -> None:
+ tokens.append('"')
+ tokens.append(_re_escape.sub(r"\\\1", s))
+ tokens.append('"')
+
+ for k, v in obj.items():
+
+ if not isinstance(k, str):
+ raise e.DataError("hstore keys can only be strings")
+ add_token(k)
+
+ tokens.append("=>")
+
+ if v is None:
+ tokens.append("NULL")
+ elif not isinstance(v, str):
+ raise e.DataError("hstore keys can only be strings")
+ else:
+ add_token(v)
+
+ tokens.append(",")
+
+ del tokens[-1]
+ data = "".join(tokens)
+ dumper = self._tx.get_dumper(data, PyFormat.TEXT)
+ return dumper.dump(data)
+
+
+class HstoreLoader(RecursiveLoader):
+ def load(self, data: Buffer) -> Hstore:
+ loader = self._tx.get_loader(TEXT_OID, self.format)
+ s: str = loader.load(data)
+
+ rv: Hstore = {}
+ start = 0
+ for m in _re_hstore.finditer(s):
+ if m is None or m.start() != start:
+ raise e.DataError(f"error parsing hstore pair at char {start}")
+ k = _re_unescape.sub(r"\1", m.group(1))
+ v = m.group(2)
+ if v is not None:
+ v = _re_unescape.sub(r"\1", v)
+
+ rv[k] = v
+ start = m.end()
+
+ if start < len(s):
+ raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
+
+ return rv
+
+
+def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump hstore.
+
+ :param info: The object with the information about the hstore type.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'hstore' extension loaded?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # Generate and register a customized text dumper
+ class HstoreDumper(BaseHstoreDumper):
+ oid = info.oid
+
+ adapters.register_dumper(dict, HstoreDumper)
+
+ # register the text loader on the oid
+ adapters.register_loader(info.oid, HstoreLoader)
diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py
new file mode 100644
index 0000000..a80e0e4
--- /dev/null
+++ b/psycopg/psycopg/types/json.py
@@ -0,0 +1,232 @@
+"""
+Adapers for JSON types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import json
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
+from .. import abc
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
+from ..errors import DataError
+
+JsonDumpsFunction = Callable[[Any], str]
+JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
+
+
+def set_json_dumps(
+ dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON serialisation function to store JSON objects in the database.
+
+ :param dumps: The dump function to use.
+ :type dumps: `!Callable[[Any], str]`
+ :param context: Where to use the `!dumps` function. If not specified, use it
+ globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default dumping JSON uses the builtin `json.dumps`. You can override
+ it to use a different JSON library or to use customised arguments.
+
+ If the `Json` wrapper specified a `!dumps` function, use it in precedence
+ of the one set by this function.
+ """
+ if context is None:
+ # If changing load function globally, just change the default on the
+ # global class
+ _JsonDumper._dumps = dumps
+ else:
+ adapters = context.adapters
+
+ # If the scope is smaller than global, create subclassess and register
+ # them in the appropriate scope.
+ grid = [
+ (Json, PyFormat.BINARY),
+ (Json, PyFormat.TEXT),
+ (Jsonb, PyFormat.BINARY),
+ (Jsonb, PyFormat.TEXT),
+ ]
+ dumper: Type[_JsonDumper]
+ for wrapper, format in grid:
+ base = _get_current_dumper(adapters, wrapper, format)
+ name = base.__name__
+ if not base.__name__.startswith("Custom"):
+ name = f"Custom{name}"
+ dumper = type(name, (base,), {"_dumps": dumps})
+ adapters.register_dumper(wrapper, dumper)
+
+
+def set_json_loads(
+ loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON parsing function to fetch JSON objects from the database.
+
+ :param loads: The load function to use.
+ :type loads: `!Callable[[bytes], Any]`
+ :param context: Where to use the `!loads` function. If not specified, use
+ it globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default loading JSON uses the builtin `json.loads`. You can override
+ it to use a different JSON library or to use customised arguments.
+ """
+ if context is None:
+ # If changing load function globally, just change the default on the
+ # global class
+ _JsonLoader._loads = loads
+ else:
+ # If the scope is smaller than global, create subclassess and register
+ # them in the appropriate scope.
+ grid = [
+ ("json", JsonLoader),
+ ("json", JsonBinaryLoader),
+ ("jsonb", JsonbLoader),
+ ("jsonb", JsonbBinaryLoader),
+ ]
+ loader: Type[_JsonLoader]
+ for tname, base in grid:
+ loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads})
+ context.adapters.register_loader(tname, loader)
+
+
+class _JsonWrapper:
+ __slots__ = ("obj", "dumps")
+
+ def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
+ self.obj = obj
+ self.dumps = dumps
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class Json(_JsonWrapper):
+ __slots__ = ()
+
+
+class Jsonb(_JsonWrapper):
+ __slots__ = ()
+
+
+class _JsonDumper(Dumper):
+
+ # The globally used JSON dumps() function. It can be changed globally (by
+ # set_json_dumps) or by a subclass.
+ _dumps: JsonDumpsFunction = json.dumps
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self.dumps = self.__class__._dumps
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return dumps(obj.obj).encode()
+
+
+class JsonDumper(_JsonDumper):
+
+ oid = postgres.types["json"].oid
+
+
+class JsonBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["json"].oid
+
+
+class JsonbDumper(_JsonDumper):
+
+ oid = postgres.types["jsonb"].oid
+
+
+class JsonbBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["jsonb"].oid
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return b"\x01" + dumps(obj.obj).encode()
+
+
+class _JsonLoader(Loader):
+
+ # The globally used JSON loads() function. It can be changed globally (by
+ # set_json_loads) or by a subclass.
+ _loads: JsonLoadsFunction = json.loads
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self.loads = self.__class__._loads
+
+ def load(self, data: Buffer) -> Any:
+ # json.loads() cannot work on memoryview.
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+class JsonLoader(_JsonLoader):
+ pass
+
+
+class JsonbLoader(_JsonLoader):
+ pass
+
+
+class JsonBinaryLoader(_JsonLoader):
+ format = Format.BINARY
+
+
+class JsonbBinaryLoader(_JsonLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Any:
+ if data and data[0] != 1:
+ raise DataError("unknown jsonb binary format: {data[0]}")
+ data = data[1:]
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+def _get_current_dumper(
+ adapters: AdaptersMap, cls: type, format: PyFormat
+) -> Type[abc.Dumper]:
+ try:
+ return adapters.get_dumper(cls, format)
+ except e.ProgrammingError:
+ return _default_dumpers[cls, format]
+
+
+_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
+ (Json, PyFormat.BINARY): JsonBinaryDumper,
+ (Json, PyFormat.TEXT): JsonDumper,
+ (Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
+ (Jsonb, PyFormat.TEXT): JsonDumper,
+}
+
+
+def register_default_adapters(context: abc.AdaptContext) -> None:
+ adapters = context.adapters
+
+ # Currently json binary format is nothing different than text, maybe with
+ # an extra memcopy we can avoid.
+ adapters.register_dumper(Json, JsonBinaryDumper)
+ adapters.register_dumper(Json, JsonDumper)
+ adapters.register_dumper(Jsonb, JsonbBinaryDumper)
+ adapters.register_dumper(Jsonb, JsonbDumper)
+ adapters.register_loader("json", JsonLoader)
+ adapters.register_loader("jsonb", JsonbLoader)
+ adapters.register_loader("json", JsonBinaryLoader)
+ adapters.register_loader("jsonb", JsonbBinaryLoader)
diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py
new file mode 100644
index 0000000..3eaa7f1
--- /dev/null
+++ b/psycopg/psycopg/types/multirange.py
@@ -0,0 +1,514 @@
+"""
+Support for multirange types adaptation.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from decimal import Decimal
+from typing import Any, Generic, List, Iterable
+from typing import MutableSequence, Optional, Type, Union, overload
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
+
+from .range import Range, T, load_range_text, load_range_binary
+from .range import dump_range_text, dump_range_binary, fail_dump
+
+
+class Multirange(MutableSequence[Range[T]]):
+ """Python representation for a PostgreSQL multirange type.
+
+ :param items: Sequence of ranges to initialise the object.
+ """
+
+ def __init__(self, items: Iterable[Range[T]] = ()):
+ self._ranges: List[Range[T]] = list(map(self._check_type, items))
+
+ def _check_type(self, item: Any) -> Range[Any]:
+ if not isinstance(item, Range):
+ raise TypeError(
+ f"Multirange is a sequence of Range, got {type(item).__name__}"
+ )
+ return item
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._ranges!r})"
+
+ def __str__(self) -> str:
+ return f"{{{', '.join(map(str, self._ranges))}}}"
+
+ @overload
+ def __getitem__(self, index: int) -> Range[T]:
+ ...
+
+ @overload
+ def __getitem__(self, index: slice) -> "Multirange[T]":
+ ...
+
+ def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
+ if isinstance(index, int):
+ return self._ranges[index]
+ else:
+ return Multirange(self._ranges[index])
+
+ def __len__(self) -> int:
+ return len(self._ranges)
+
+ @overload
+ def __setitem__(self, index: int, value: Range[T]) -> None:
+ ...
+
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
+ ...
+
+ def __setitem__(
+ self,
+ index: Union[int, slice],
+ value: Union[Range[T], Iterable[Range[T]]],
+ ) -> None:
+ if isinstance(index, int):
+ self._check_type(value)
+ self._ranges[index] = self._check_type(value)
+ elif not isinstance(value, Iterable):
+ raise TypeError("can only assign an iterable")
+ else:
+ value = map(self._check_type, value)
+ self._ranges[index] = value
+
+ def __delitem__(self, index: Union[int, slice]) -> None:
+ del self._ranges[index]
+
+ def insert(self, index: int, value: Range[T]) -> None:
+ self._ranges.insert(index, self._check_type(value))
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return False
+ return self._ranges == other._ranges
+
+ # Order is arbitrary but consistent
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges < other._ranges
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges > other._ranges
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
+
+# Subclasses to specify a specific subtype. Usually not needed
+
+
+class Int4Multirange(Multirange[int]):
+ pass
+
+
+class Int8Multirange(Multirange[int]):
+ pass
+
+
+class NumericMultirange(Multirange[Decimal]):
+ pass
+
+
+class DateMultirange(Multirange[date]):
+ pass
+
+
+class TimestampMultirange(Multirange[datetime]):
+ pass
+
+
+class TimestamptzMultirange(Multirange[datetime]):
+ pass
+
+
+class BaseMultirangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return MultirangeDumper(self.cls)
+
+ dumper: BaseMultirangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = MultirangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_multirange_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_multirange_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Multirange[Any]) -> Any:
+ """
+ Return a member representative of the multirange
+ """
+ for r in obj:
+ if r.lower is not None:
+ return r.lower
+ if r.upper is not None:
+ return r.upper
+ return None
+
+ def _get_multirange_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class MultirangeDumper(BaseMultirangeDumper):
+ """
+ Dumper for multirange types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ if not obj:
+ return b"{}"
+
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [b"{"]
+ for r in obj:
+ out.append(dump_range_text(r, dump))
+ out.append(b",")
+ out[-1] = b"}"
+ return b"".join(out)
+
+
+class MultirangeBinaryDumper(BaseMultirangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [pack_len(len(obj))]
+ for r in obj:
+ data = dump_range_binary(r, dump)
+ out.append(pack_len(len(data)))
+ out.append(data)
+ return b"".join(out)
+
+
+class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class MultirangeLoader(BaseMultirangeLoader[T]):
+ def load(self, data: Buffer) -> Multirange[T]:
+ if not data or data[0] != _START_INT:
+ raise e.DataError(
+ "malformed multirange starting with"
+ f" {bytes(data[:1]).decode('utf8', 'replace')}"
+ )
+
+ out = Multirange[T]()
+ if data == b"{}":
+ return out
+
+ pos = 1
+ data = data[pos:]
+ try:
+ while True:
+ r, pos = load_range_text(data, self._load)
+ out.append(r)
+
+ sep = data[pos] # can raise IndexError
+ if sep == _SEP_INT:
+ data = data[pos + 1 :]
+ continue
+ elif sep == _END_INT:
+ if len(data) == pos + 1:
+ return out
+ else:
+ raise e.DataError(
+ "malformed multirange: data after closing brace"
+ )
+ else:
+ raise e.DataError(
+ f"malformed multirange: found unexpected {chr(sep)}"
+ )
+
+ except IndexError:
+ raise e.DataError("malformed multirange: separator missing")
+
+ return out
+
+
+_SEP_INT = ord(",")
+_START_INT = ord("{")
+_END_INT = ord("}")
+
+
+class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Multirange[T]:
+ nelems = unpack_len(data, 0)[0]
+ pos = 4
+ out = Multirange[T]()
+ for i in range(nelems):
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ out.append(load_range_binary(data[pos : pos + length], self._load))
+ pos += length
+
+ if pos != len(data):
+ raise e.DataError("unexpected trailing data in multirange")
+
+ return out
+
+
+def register_multirange(
+ info: MultirangeInfo, context: Optional[AdaptContext] = None
+) -> None:
+ """Register the adapters to load and dump a multirange type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested multirange available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[MultirangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (MultirangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[MultirangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (MultirangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
+
+# Text dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Binary dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Text loaders for builtin multirange types
+
+
+class Int4MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeLoader(MultirangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeLoader(MultirangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin multirange types
+
+
+class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Multirange, MultirangeBinaryDumper)
+ adapters.register_dumper(Multirange, MultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
+ adapters.register_loader("int4multirange", Int4MultirangeLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeLoader)
+ adapters.register_loader("datemultirange", DateMultirangeLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
+ adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
+ adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py
new file mode 100644
index 0000000..2f2c05b
--- /dev/null
+++ b/psycopg/psycopg/types/net.py
@@ -0,0 +1,206 @@
+"""
+Adapters for network types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, Type, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import ipaddress
+
+Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
+Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
+Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
+
+# These objects will be imported lazily
+ip_address: Callable[[str], Address] = None # type: ignore[assignment]
+ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
+ip_network: Callable[[str], Network] = None # type: ignore[assignment]
+IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
+IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
+IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
+IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
+IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
+IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
+
+PGSQL_AF_INET = 2
+PGSQL_AF_INET6 = 3
+IPV4_PREFIXLEN = 32
+IPV6_PREFIXLEN = 128
+
+
+class _LazyIpaddress:
+ def _ensure_module(self) -> None:
+ global ip_address, ip_interface, ip_network
+ global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
+ global IPv4Network, IPv6Network
+
+ if ip_address is None:
+ from ipaddress import ip_address, ip_interface, ip_network
+ from ipaddress import IPv4Address, IPv6Address
+ from ipaddress import IPv4Interface, IPv6Interface
+ from ipaddress import IPv4Network, IPv6Network
+
+
+class InterfaceDumper(Dumper):
+
+ oid = postgres.types["inet"].oid
+
+ def dump(self, obj: Interface) -> bytes:
+ return str(obj).encode()
+
+
+class NetworkDumper(Dumper):
+
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ return str(obj).encode()
+
+
+class _AIBinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["inet"].oid
+
+
+class AddressBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Address) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.max_prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InterfaceBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Interface) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.network.prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
+ """Either an address or an interface to inet
+
+ Used when looking up by oid.
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._ensure_module()
+
+ def dump(self, obj: Union[Address, Interface]) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ if isinstance(obj, (IPv4Interface, IPv6Interface)):
+ prefixlen = obj.network.prefixlen
+ else:
+ prefixlen = obj.max_prefixlen
+
+ head = bytes((family, prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class NetworkBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ packed = obj.network_address.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.prefixlen, 1, len(packed)))
+ return head + packed
+
+
+class _LazyIpaddressLoader(Loader, _LazyIpaddress):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._ensure_module()
+
+
+class InetLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ if b"/" in data:
+ return ip_interface(data.decode())
+ else:
+ return ip_address(data.decode())
+
+
+class InetBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ if prefix == IPV4_PREFIXLEN:
+ return IPv4Address(packed)
+ else:
+ return IPv4Interface((packed, prefix))
+ else:
+ if prefix == IPV6_PREFIXLEN:
+ return IPv6Address(packed)
+ else:
+ return IPv6Interface((packed, prefix))
+
+
+class CidrLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ return ip_network(data.decode())
+
+
+class CidrBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ return IPv4Network((packed, prefix))
+ else:
+ return IPv6Network((packed, prefix))
+
+ return ip_network(data.decode())
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
+ adapters.register_dumper(None, InetBinaryDumper)
+ adapters.register_loader("inet", InetLoader)
+ adapters.register_loader("inet", InetBinaryLoader)
+ adapters.register_loader("cidr", CidrLoader)
+ adapters.register_loader("cidr", CidrBinaryLoader)
diff --git a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py
new file mode 100644
index 0000000..2ab857c
--- /dev/null
+++ b/psycopg/psycopg/types/none.py
@@ -0,0 +1,25 @@
+"""
+Adapters for None.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ..abc import AdaptContext, NoneType
+from ..adapt import Dumper
+
+
+class NoneDumper(Dumper):
+ """
+ Not a complete dumper as it doesn't implement dump(), but it implements
+ quote(), so it can be used in sql composition.
+ """
+
+ def dump(self, obj: None) -> bytes:
+ raise NotImplementedError("NULL is passed to Postgres in other ways")
+
+ def quote(self, obj: None) -> bytes:
+ return b"NULL"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, NoneDumper)
diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py
new file mode 100644
index 0000000..1bd9329
--- /dev/null
+++ b/psycopg/psycopg/types/numeric.py
@@ -0,0 +1,515 @@
+"""
+Adapers for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from math import log
+from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
+from decimal import Decimal, DefaultContext, Context
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from .._struct import pack_int2, pack_uint2, unpack_int2
+from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
+from .._struct import pack_int8, unpack_int8
+from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
+
+# Exposed here
+from .._wrappers import (
+ Int2 as Int2,
+ Int4 as Int4,
+ Int8 as Int8,
+ IntNumeric as IntNumeric,
+ Oid as Oid,
+ Float4 as Float4,
+ Float8 as Float8,
+)
+
+
+class _IntDumper(Dumper):
+ def dump(self, obj: Any) -> Buffer:
+ t = type(obj)
+ if t is not int:
+ # Convert to int in order to dump IntEnum correctly
+ if issubclass(t, int):
+ obj = int(obj)
+ else:
+ raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> Buffer:
+ value = self.dump(obj)
+ return value if obj >= 0 else b" " + value
+
+
+class _SpecialValuesDumper(Dumper):
+
+ _special: Dict[bytes, bytes] = {}
+
+ def dump(self, obj: Any) -> bytes:
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+
+ if value in self._special:
+ return self._special[value]
+
+ return value if obj >= 0 else b" " + value
+
+
+class FloatDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["float8"].oid
+
+ _special = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+ }
+
+
+class Float4Dumper(FloatDumper):
+ oid = postgres.types["float4"].oid
+
+
+class FloatBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["float8"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float8(obj)
+
+
+class Float4BinaryDumper(FloatBinaryDumper):
+
+ oid = postgres.types["float4"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float4(obj)
+
+
+class DecimalDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> bytes:
+ if obj.is_nan():
+ # cover NaN and sNaN
+ return b"NaN"
+ else:
+ return str(obj).encode()
+
+ _special = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+ }
+
+
+class Int2Dumper(_IntDumper):
+ oid = postgres.types["int2"].oid
+
+
+class Int4Dumper(_IntDumper):
+ oid = postgres.types["int4"].oid
+
+
+class Int8Dumper(_IntDumper):
+ oid = postgres.types["int8"].oid
+
+
+class IntNumericDumper(_IntDumper):
+ oid = postgres.types["numeric"].oid
+
+
+class OidDumper(_IntDumper):
+ oid = postgres.types["oid"].oid
+
+
+class IntDumper(Dumper):
+ def dump(self, obj: Any) -> bytes:
+ raise TypeError(
+ f"{type(self).__name__} is a dispatcher to other dumpers:"
+ " dump() is not supposed to be called"
+ )
+
+ def get_key(self, obj: int, format: PyFormat) -> type:
+ return self.upgrade(obj, format).cls
+
+ _int2_dumper = Int2Dumper(Int2)
+ _int4_dumper = Int4Dumper(Int4)
+ _int8_dumper = Int8Dumper(Int8)
+ _int_numeric_dumper = IntNumericDumper(IntNumeric)
+
+ def upgrade(self, obj: int, format: PyFormat) -> Dumper:
+ if -(2**31) <= obj < 2**31:
+ if -(2**15) <= obj < 2**15:
+ return self._int2_dumper
+ else:
+ return self._int4_dumper
+ else:
+ if -(2**63) <= obj < 2**63:
+ return self._int8_dumper
+ else:
+ return self._int_numeric_dumper
+
+
+class Int2BinaryDumper(Int2Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int2(obj)
+
+
+class Int4BinaryDumper(Int4Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int4(obj)
+
+
+class Int8BinaryDumper(Int8Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int8(obj)
+
+
+# Ratio between number of bits required to store a number and number of pg
+# decimal digits required.
+BIT_PER_PGDIGIT = log(2) / log(10_000)
+
+
+class IntNumericBinaryDumper(IntNumericDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> Buffer:
+ return dump_int_to_numeric_binary(obj)
+
+
+class OidBinaryDumper(OidDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_uint4(obj)
+
+
+class IntBinaryDumper(IntDumper):
+
+ format = Format.BINARY
+
+ _int2_dumper = Int2BinaryDumper(Int2)
+ _int4_dumper = Int4BinaryDumper(Int4)
+ _int8_dumper = Int8BinaryDumper(Int8)
+ _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
+
+
+class IntLoader(Loader):
+ def load(self, data: Buffer) -> int:
+ # it supports bytes directly
+ return int(data)
+
+
+class Int2BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int2(data)[0]
+
+
+class Int4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int4(data)[0]
+
+
+class Int8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int8(data)[0]
+
+
+class OidBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_uint4(data)[0]
+
+
+class FloatLoader(Loader):
+ def load(self, data: Buffer) -> float:
+ # it supports bytes directly
+ return float(data)
+
+
+class Float4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float4(data)[0]
+
+
+class Float8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float8(data)[0]
+
+
+class NumericLoader(Loader):
+ def load(self, data: Buffer) -> Decimal:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return Decimal(data.decode())
+
+
+DEC_DIGITS = 4 # decimal digits per Postgres "digit"
+NUMERIC_POS = 0x0000
+NUMERIC_NEG = 0x4000
+NUMERIC_NAN = 0xC000
+NUMERIC_PINF = 0xD000
+NUMERIC_NINF = 0xF000
+
+_decimal_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+
+class _ContextMap(DefaultDict[int, Context]):
+ """
+ Cache for decimal contexts to use when the precision requires it.
+
+ Note: if the default context is used (prec=28) you can get an invalid
+ operation or a rounding to 0:
+
+ - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
+ - Decimal(1000).shift(25) = Decimal('0')
+ - Decimal(1000).shift(30) raises InvalidOperation
+ """
+
+ def __missing__(self, key: int) -> Context:
+ val = Context(prec=key)
+ self[key] = val
+ return val
+
+
+_contexts = _ContextMap()
+for i in range(DefaultContext.prec):
+ _contexts[i] = DefaultContext
+
+_unpack_numeric_head = cast(
+ Callable[[Buffer], Tuple[int, int, int, int]],
+ struct.Struct("!HhHH").unpack_from,
+)
+_pack_numeric_head = cast(
+ Callable[[int, int, int, int], bytes],
+ struct.Struct("!HhHH").pack,
+)
+
+
+class NumericBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Decimal:
+ ndigits, weight, sign, dscale = _unpack_numeric_head(data)
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ val = 0
+ for i in range(8, len(data), 2):
+ val = val * 10_000 + data[i] * 0x100 + data[i + 1]
+
+ shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+ ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, ctx)
+ .shift(shift, ctx)
+ )
+ else:
+ try:
+ return _decimal_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
+
+
+NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
+NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
+NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
+
+
+class DecimalBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> Buffer:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+class NumericDumper(DecimalDumper):
+ def dump(self, obj: Union[Decimal, int]) -> bytes:
+ if isinstance(obj, int):
+ return str(obj).encode()
+ else:
+ return super().dump(obj)
+
+
+class NumericBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Union[Decimal, int]) -> Buffer:
+ if isinstance(obj, int):
+ return dump_int_to_numeric_binary(obj)
+ else:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
+ sign, digits, exp = obj.as_tuple()
+ if exp == "n" or exp == "N": # type: ignore[comparison-overlap]
+ return NUMERIC_NAN_BIN
+ elif exp == "F": # type: ignore[comparison-overlap]
+ return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+
+ # Weights of py digits into a pg digit according to their positions.
+ # Starting with an index wi != 0 is equivalent to prepending 0's to
+ # the digits tuple, but without really changing it.
+ weights = (1000, 100, 10, 1)
+ wi = 0
+
+ ndigits = nzdigits = len(digits)
+
+ # Find the last nonzero digit
+ while nzdigits > 0 and digits[nzdigits - 1] == 0:
+ nzdigits -= 1
+
+ if exp <= 0:
+ dscale = -exp
+ else:
+ dscale = 0
+ # align the py digits to the pg digits if there's some py exponent
+ ndigits += exp % DEC_DIGITS
+
+ if not nzdigits:
+ return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
+
+ # Equivalent of 0-padding left to align the py digits to the pg digits
+ # but without changing the digits tuple.
+ mod = (ndigits - dscale) % DEC_DIGITS
+ if mod:
+ wi = DEC_DIGITS - mod
+ ndigits += wi
+
+ tmp = nzdigits + wi
+ out = bytearray(
+ _pack_numeric_head(
+ tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits
+ (ndigits + exp) // DEC_DIGITS - 1, # weight
+ NUMERIC_NEG if sign else NUMERIC_POS, # sign
+ dscale,
+ )
+ )
+
+ pgdigit = 0
+ for i in range(nzdigits):
+ pgdigit += weights[wi] * digits[i]
+ wi += 1
+ if wi >= DEC_DIGITS:
+ out += pack_uint2(pgdigit)
+ pgdigit = wi = 0
+
+ if pgdigit:
+ out += pack_uint2(pgdigit)
+
+ return out
+
+
+def dump_int_to_numeric_binary(obj: int) -> bytearray:
+ ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
+ out = bytearray(b"\x00\x00" * (ndigits + 4))
+ if obj < 0:
+ sign = NUMERIC_NEG
+ obj = -obj
+ else:
+ sign = NUMERIC_POS
+
+ out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
+ i = 8 + (ndigits - 1) * 2
+ while obj:
+ rem = obj % 10_000
+ obj //= 10_000
+ out[i : i + 2] = pack_uint2(rem)
+ i -= 2
+
+ return out
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(int, IntDumper)
+ adapters.register_dumper(int, IntBinaryDumper)
+ adapters.register_dumper(float, FloatDumper)
+ adapters.register_dumper(float, FloatBinaryDumper)
+ adapters.register_dumper(Int2, Int2Dumper)
+ adapters.register_dumper(Int4, Int4Dumper)
+ adapters.register_dumper(Int8, Int8Dumper)
+ adapters.register_dumper(IntNumeric, IntNumericDumper)
+ adapters.register_dumper(Oid, OidDumper)
+
+ # The binary dumper is currently some 30% slower, so default to text
+ # (see tests/scripts/testdec.py for a rough benchmark)
+ # Also, must be after IntNumericDumper
+ adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
+ adapters.register_dumper("decimal.Decimal", DecimalDumper)
+
+ # Used only by oid, can take both int and Decimal as input
+ adapters.register_dumper(None, NumericBinaryDumper)
+ adapters.register_dumper(None, NumericDumper)
+
+ adapters.register_dumper(Float4, Float4Dumper)
+ adapters.register_dumper(Float8, FloatDumper)
+ adapters.register_dumper(Int2, Int2BinaryDumper)
+ adapters.register_dumper(Int4, Int4BinaryDumper)
+ adapters.register_dumper(Int8, Int8BinaryDumper)
+ adapters.register_dumper(Oid, OidBinaryDumper)
+ adapters.register_dumper(Float4, Float4BinaryDumper)
+ adapters.register_dumper(Float8, FloatBinaryDumper)
+ adapters.register_loader("int2", IntLoader)
+ adapters.register_loader("int4", IntLoader)
+ adapters.register_loader("int8", IntLoader)
+ adapters.register_loader("oid", IntLoader)
+ adapters.register_loader("int2", Int2BinaryLoader)
+ adapters.register_loader("int4", Int4BinaryLoader)
+ adapters.register_loader("int8", Int8BinaryLoader)
+ adapters.register_loader("oid", OidBinaryLoader)
+ adapters.register_loader("float4", FloatLoader)
+ adapters.register_loader("float8", FloatLoader)
+ adapters.register_loader("float4", Float4BinaryLoader)
+ adapters.register_loader("float8", Float8BinaryLoader)
+ adapters.register_loader("numeric", NumericLoader)
+ adapters.register_loader("numeric", NumericBinaryLoader)
diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py
new file mode 100644
index 0000000..c418480
--- /dev/null
+++ b/psycopg/psycopg/types/range.py
@@ -0,0 +1,700 @@
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
+from typing import cast
+from decimal import Decimal
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import RangeInfo as RangeInfo # exported here
+
+RANGE_EMPTY = 0x01 # range is empty
+RANGE_LB_INC = 0x02 # lower bound is inclusive
+RANGE_UB_INC = 0x04 # upper bound is inclusive
+RANGE_LB_INF = 0x08 # lower bound is -infinity
+RANGE_UB_INF = 0x10 # upper bound is +infinity
+
+_EMPTY_HEAD = bytes([RANGE_EMPTY])
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+ """Python representation for a PostgreSQL range type.
+
+ :param lower: lower bound for the range. `!None` means unbound
+ :param upper: upper bound for the range. `!None` means unbound
+ :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+ representing whether the lower or upper bounds are included
+ :param empty: if `!True`, the range is empty
+
+ """
+
+ __slots__ = ("_lower", "_upper", "_bounds")
+
+ def __init__(
+ self,
+ lower: Optional[T] = None,
+ upper: Optional[T] = None,
+ bounds: str = "[)",
+ empty: bool = False,
+ ):
+ if not empty:
+ if bounds not in ("[)", "(]", "()", "[]"):
+ raise ValueError("bound flags not valid: %r" % bounds)
+
+ self._lower = lower
+ self._upper = upper
+
+ # Make bounds consistent with infs
+ if lower is None and bounds[0] == "[":
+ bounds = "(" + bounds[1]
+ if upper is None and bounds[1] == "]":
+ bounds = bounds[0] + ")"
+
+ self._bounds = bounds
+ else:
+ self._lower = self._upper = None
+ self._bounds = ""
+
+ def __repr__(self) -> str:
+ if self._bounds:
+ args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}"
+ else:
+ args = "empty=True"
+
+ return f"{self.__class__.__name__}({args})"
+
+ def __str__(self) -> str:
+ if not self._bounds:
+ return "empty"
+
+ items = [
+ self._bounds[0],
+ str(self._lower),
+ ", ",
+ str(self._upper),
+ self._bounds[1],
+ ]
+ return "".join(items)
+
+ @property
+ def lower(self) -> Optional[T]:
+ """The lower bound of the range. `!None` if empty or unbound."""
+ return self._lower
+
+ @property
+ def upper(self) -> Optional[T]:
+ """The upper bound of the range. `!None` if empty or unbound."""
+ return self._upper
+
+ @property
+ def bounds(self) -> str:
+ """The bounds string (two characters from '[', '(', ']', ')')."""
+ return self._bounds
+
+ @property
+ def isempty(self) -> bool:
+ """`!True` if the range is empty."""
+ return not self._bounds
+
+ @property
+ def lower_inf(self) -> bool:
+ """`!True` if the range doesn't have a lower bound."""
+ if not self._bounds:
+ return False
+ return self._lower is None
+
+ @property
+ def upper_inf(self) -> bool:
+ """`!True` if the range doesn't have an upper bound."""
+ if not self._bounds:
+ return False
+ return self._upper is None
+
+ @property
+ def lower_inc(self) -> bool:
+ """`!True` if the lower bound is included in the range."""
+ if not self._bounds or self._lower is None:
+ return False
+ return self._bounds[0] == "["
+
+ @property
+ def upper_inc(self) -> bool:
+ """`!True` if the upper bound is included in the range."""
+ if not self._bounds or self._upper is None:
+ return False
+ return self._bounds[1] == "]"
+
+ def __contains__(self, x: T) -> bool:
+ if not self._bounds:
+ return False
+
+ if self._lower is not None:
+ if self._bounds[0] == "[":
+ # It doesn't seem that Python has an ABC for ordered types.
+ if x < self._lower: # type: ignore[operator]
+ return False
+ else:
+ if x <= self._lower: # type: ignore[operator]
+ return False
+
+ if self._upper is not None:
+ if self._bounds[1] == "]":
+ if x > self._upper: # type: ignore[operator]
+ return False
+ else:
+ if x >= self._upper: # type: ignore[operator]
+ return False
+
+ return True
+
+ def __bool__(self) -> bool:
+ return bool(self._bounds)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return False
+ return (
+ self._lower == other._lower
+ and self._upper == other._upper
+ and self._bounds == other._bounds
+ )
+
+ def __hash__(self) -> int:
+ return hash((self._lower, self._upper, self._bounds))
+
+ # as the postgres docs describe for the server-side stuff,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return NotImplemented
+ for attr in ("_lower", "_upper", "_bounds"):
+ self_value = getattr(self, attr)
+ other_value = getattr(other, attr)
+ if self_value == other_value:
+ pass
+ elif self_value is None:
+ return True
+ elif other_value is None:
+ return False
+ else:
+ return cast(bool, self_value < other_value)
+ return False
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if isinstance(other, Range):
+ return other < self
+ else:
+ return NotImplemented
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
+ def __getstate__(self) -> Dict[str, Any]:
+ return {
+ slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
+ }
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ for slot, value in state.items():
+ setattr(self, slot, value)
+
+
+# Subclasses to specify a specific subtype. Usually not needed: only needed
+# in binary copy, where switching to text is not an option.
+
+
+class Int4Range(Range[int]):
+ pass
+
+
+class Int8Range(Range[int]):
+ pass
+
+
+class NumericRange(Range[Decimal]):
+ pass
+
+
+class DateRange(Range[date]):
+ pass
+
+
+class TimestampRange(Range[datetime]):
+ pass
+
+
+class TimestamptzRange(Range[datetime]):
+ pass
+
+
+class BaseRangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return RangeDumper(self.cls)
+
+ dumper: BaseRangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = RangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_range_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_range_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Range[Any]) -> Any:
+ """
+ Return a member representative of the range
+ """
+ rv = obj.lower
+ return rv if rv is not None else obj.upper
+
+ def _get_range_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class RangeDumper(BaseRangeDumper):
+ """
+ Dumper for range types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_text(obj, dump)
+
+
+def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if obj.isempty:
+ return b"empty"
+
+ parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
+
+ def dump_item(item: Any) -> Buffer:
+ ad = dump(item)
+ if not ad:
+ return b'""'
+ elif _re_needs_quotes.search(ad):
+ return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"'
+ else:
+ return ad
+
+ if obj.lower is not None:
+ parts.append(dump_item(obj.lower))
+
+ parts.append(b",")
+
+ if obj.upper is not None:
+ parts.append(dump_item(obj.upper))
+
+ parts.append(b"]" if obj.upper_inc else b")")
+
+ return b"".join(parts)
+
+
+_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]')
+_re_esc = re.compile(rb"([\\\"])")
+
+
+class RangeBinaryDumper(BaseRangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_binary(obj, dump)
+
+
+def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if not obj:
+ return _EMPTY_HEAD
+
+ out = bytearray([0]) # will replace the head later
+
+ head = 0
+ if obj.lower_inc:
+ head |= RANGE_LB_INC
+ if obj.upper_inc:
+ head |= RANGE_UB_INC
+
+ if obj.lower is not None:
+ data = dump(obj.lower)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_LB_INF
+
+ if obj.upper is not None:
+ data = dump(obj.upper)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_UB_INF
+
+ out[0] = head
+ return out
+
+
+def fail_dump(obj: Any) -> Buffer:
+ raise e.InternalError("trying to dump a range element without information")
+
+
+class BaseRangeLoader(RecursiveLoader, Generic[T]):
+ """Generic loader for a range.
+
+ Subclasses must specify the oid of the subtype and the class to load.
+ """
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class RangeLoader(BaseRangeLoader[T]):
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_text(data, self._load)[0]
+
+
+def load_range_text(
+ data: Buffer, load: Callable[[Buffer], Any]
+) -> Tuple[Range[Any], int]:
+ if data == b"empty":
+ return Range(empty=True), 5
+
+ m = _re_range.match(data)
+ if m is None:
+ raise e.DataError(
+ f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
+ )
+
+ lower = None
+ item = m.group(3)
+ if item is None:
+ item = m.group(2)
+ if item is not None:
+ lower = load(_re_undouble.sub(rb"\1", item))
+ else:
+ lower = load(item)
+
+ upper = None
+ item = m.group(5)
+ if item is None:
+ item = m.group(4)
+ if item is not None:
+ upper = load(_re_undouble.sub(rb"\1", item))
+ else:
+ upper = load(item)
+
+ bounds = (m.group(1) + m.group(6)).decode()
+
+ return Range(lower, upper, bounds), m.end()
+
+
+_re_range = re.compile(
+ rb"""
+ ( \(|\[ ) # lower bound flag
+ (?: # lower bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^",]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ,
+ (?: # upper bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^"\)\]]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ( \)|\] ) # upper bound flag
+ """,
+ re.VERBOSE,
+)
+
+_re_undouble = re.compile(rb'(["\\])\1')
+
+
+class RangeBinaryLoader(BaseRangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_binary(data, self._load)
+
+
+def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
+ head = data[0]
+ if head & RANGE_EMPTY:
+ return Range(empty=True)
+
+ lb = "[" if head & RANGE_LB_INC else "("
+ ub = "]" if head & RANGE_UB_INC else ")"
+
+ pos = 1 # after the head
+ if head & RANGE_LB_INF:
+ min = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ min = load(data[pos : pos + length])
+ pos += length
+
+ if head & RANGE_UB_INF:
+ max = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ max = load(data[pos : pos + length])
+ pos += length
+
+ return Range(min, max, lb + ub)
+
+
+def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump a range type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested range available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[RangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (RangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[RangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (RangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
+
+# Text dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeDumper(RangeDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeDumper(RangeDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeDumper(RangeDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeDumper(RangeDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeDumper(RangeDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeDumper(RangeDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Binary dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Text loaders for builtin range types
+
+
+class Int4RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeLoader(RangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeLoader(RangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin range types
+
+
+class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeBinaryLoader(RangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Range, RangeBinaryDumper)
+ adapters.register_dumper(Range, RangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeDumper)
+ adapters.register_dumper(Int8Range, Int8RangeDumper)
+ adapters.register_dumper(NumericRange, NumericRangeDumper)
+ adapters.register_dumper(DateRange, DateRangeDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeBinaryDumper)
+ adapters.register_dumper(Int8Range, Int8RangeBinaryDumper)
+ adapters.register_dumper(NumericRange, NumericRangeBinaryDumper)
+ adapters.register_dumper(DateRange, DateRangeBinaryDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper)
+ adapters.register_loader("int4range", Int4RangeLoader)
+ adapters.register_loader("int8range", Int8RangeLoader)
+ adapters.register_loader("numrange", NumericRangeLoader)
+ adapters.register_loader("daterange", DateRangeLoader)
+ adapters.register_loader("tsrange", TimestampRangeLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeLoader)
+ adapters.register_loader("int4range", Int4RangeBinaryLoader)
+ adapters.register_loader("int8range", Int8RangeBinaryLoader)
+ adapters.register_loader("numrange", NumericRangeBinaryLoader)
+ adapters.register_loader("daterange", DateRangeBinaryLoader)
+ adapters.register_loader("tsrange", TimestampRangeBinaryLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader)
diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py
new file mode 100644
index 0000000..e99f256
--- /dev/null
+++ b/psycopg/psycopg/types/shapely.py
@@ -0,0 +1,75 @@
+"""
+Adapters for PostGIS geometries
+"""
+
+from typing import Optional
+
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Dumper, Loader
+from ..pq import Format
+from .._typeinfo import TypeInfo
+
+
+try:
+ from shapely.wkb import loads, dumps
+ from shapely.geometry.base import BaseGeometry
+
+except ImportError:
+ raise ImportError(
+ "The module psycopg.types.shapely requires the package 'Shapely'"
+ " to be installed"
+ )
+
+
+class GeometryBinaryLoader(Loader):
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "BaseGeometry":
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return loads(data)
+
+
+class GeometryLoader(Loader):
+ def load(self, data: Buffer) -> "BaseGeometry":
+ # it's a hex string in binary
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return loads(data.decode(), hex=True)
+
+
+class BaseGeometryBinaryDumper(Dumper):
+ format = Format.BINARY
+
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj) # type: ignore
+
+
+class BaseGeometryDumper(Dumper):
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj, hex=True).encode() # type: ignore
+
+
+def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register Shapely dumper and loaders."""
+
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'postgis' extension loaded?")
+
+ info.register(context)
+ adapters = context.adapters if context else postgres.adapters
+
+ class GeometryDumper(BaseGeometryDumper):
+ oid = info.oid
+
+ class GeometryBinaryDumper(BaseGeometryBinaryDumper):
+ oid = info.oid
+
+ adapters.register_loader(info.oid, GeometryBinaryLoader)
+ adapters.register_loader(info.oid, GeometryLoader)
+ # Default binary dump
+ adapters.register_dumper(BaseGeometry, GeometryDumper)
+ adapters.register_dumper(BaseGeometry, GeometryBinaryDumper)
diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py
new file mode 100644
index 0000000..cd5360d
--- /dev/null
+++ b/psycopg/psycopg/types/string.py
@@ -0,0 +1,239 @@
+"""
+Adapters for textual types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Optional, Union, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format, Escaping
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from ..errors import DataError
+from .._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from ..pq.abc import Escaping as EscapingProto
+
+
+class _BaseStrDumper(Dumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else "utf-8"
+
+
+class _StrBinaryDumper(_BaseStrDumper):
+ """
+ Base class to dump a Python strings to a Postgres text type, in binary format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ format = Format.BINARY
+
+ def dump(self, obj: str) -> bytes:
+ # the server will raise DataError subclass if the string contains 0x00
+ return obj.encode(self._encoding)
+
+
+class _StrDumper(_BaseStrDumper):
+ """
+ Base class to dump a Python strings to a Postgres text type, in text format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ def dump(self, obj: str) -> bytes:
+ if "\x00" in obj:
+ raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
+ else:
+ return obj.encode(self._encoding)
+
+
+# The next are concrete dumpers, each one specifying the oid they dump to.
+
+
+class StrBinaryDumper(_StrBinaryDumper):
+
+ oid = postgres.types["text"].oid
+
+
+class StrBinaryDumperVarchar(_StrBinaryDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrBinaryDumperName(_StrBinaryDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumper(_StrDumper):
+ """
+ Dumper for strings in text format to the text oid.
+
+ Note that this dumper is not used by default because the type is too strict
+ and PostgreSQL would require an explicit casts to everything that is not a
+ text field. However it is useful where the unknown oid is ambiguous and the
+ text oid is required, for instance with variadic functions.
+ """
+
+ oid = postgres.types["text"].oid
+
+
+class StrDumperVarchar(_StrDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrDumperName(_StrDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumperUnknown(_StrDumper):
+ """
+ Dumper for strings in text format to the unknown oid.
+
+ This dumper is the default dumper for strings and allows to use Python
+ strings to represent almost every data type. In a few places, however, the
+ unknown oid is not accepted (for instance in variadic functions such as
+ 'concat()'). In that case either a cast on the placeholder ('%s::text') or
+ the StrTextDumper should be used.
+ """
+
+ pass
+
+
+class TextLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else ""
+
+ def load(self, data: Buffer) -> Union[bytes, str]:
+ if self._encoding:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return data.decode(self._encoding)
+ else:
+ # return bytes for SQL_ASCII db
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return data
+
+
+class TextBinaryLoader(TextLoader):
+
+ format = Format.BINARY
+
+
+class BytesDumper(Dumper):
+
+ oid = postgres.types["bytea"].oid
+ _qprefix = b""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._esc = Escaping(self.connection.pgconn if self.connection else None)
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return self._esc.escape_bytea(obj)
+
+ def quote(self, obj: Buffer) -> bytes:
+ escaped = self.dump(obj)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes content. if scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self.connection:
+ if not self._qprefix:
+ scs = self.connection.pgconn.parameter_status(
+ b"standard_conforming_strings"
+ )
+ self._qprefix = b"'" if scs == b"on" else b" E'"
+
+ return self._qprefix + escaped + b"'"
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. PQescapeBytea, like its
+ # string counterpart, is not predictable whether it will escape
+ # backslashes.
+ rv: bytes = b" E'" + escaped + b"'"
+ if self._esc.escape_bytea(b"\x00") == b"\\000":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+
+class BytesBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bytea"].oid
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return obj
+
+
+class ByteaLoader(Loader):
+
+ _escaping: "EscapingProto"
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if not hasattr(self.__class__, "_escaping"):
+ self.__class__._escaping = Escaping()
+
+ def load(self, data: Buffer) -> bytes:
+ return self._escaping.unescape_bytea(data)
+
+
+class ByteaBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Buffer:
+ return data
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+
+ # NOTE: the order the dumpers are registered is relevant. The last one
+ # registered becomes the default for each type. Usually, binary is the
+ # default dumper. For text we use the text dumper as default because it
+ # plays the role of unknown, and it can be cast automatically to other
+ # types. However, before that, we register dumper with 'text', 'varchar',
+ # 'name' oids, which will be used when a text dumper is looked up by oid.
+ adapters.register_dumper(str, StrBinaryDumperName)
+ adapters.register_dumper(str, StrBinaryDumperVarchar)
+ adapters.register_dumper(str, StrBinaryDumper)
+ adapters.register_dumper(str, StrDumperName)
+ adapters.register_dumper(str, StrDumperVarchar)
+ adapters.register_dumper(str, StrDumper)
+ adapters.register_dumper(str, StrDumperUnknown)
+
+ adapters.register_loader(postgres.INVALID_OID, TextLoader)
+ adapters.register_loader("bpchar", TextLoader)
+ adapters.register_loader("name", TextLoader)
+ adapters.register_loader("text", TextLoader)
+ adapters.register_loader("varchar", TextLoader)
+ adapters.register_loader('"char"', TextLoader)
+ adapters.register_loader("bpchar", TextBinaryLoader)
+ adapters.register_loader("name", TextBinaryLoader)
+ adapters.register_loader("text", TextBinaryLoader)
+ adapters.register_loader("varchar", TextBinaryLoader)
+ adapters.register_loader('"char"', TextBinaryLoader)
+
+ adapters.register_dumper(bytes, BytesDumper)
+ adapters.register_dumper(bytearray, BytesDumper)
+ adapters.register_dumper(memoryview, BytesDumper)
+ adapters.register_dumper(bytes, BytesBinaryDumper)
+ adapters.register_dumper(bytearray, BytesBinaryDumper)
+ adapters.register_dumper(memoryview, BytesBinaryDumper)
+
+ adapters.register_loader("bytea", ByteaLoader)
+ adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader)
+ adapters.register_loader("bytea", ByteaBinaryLoader)
diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py
new file mode 100644
index 0000000..f92354c
--- /dev/null
+++ b/psycopg/psycopg/types/uuid.py
@@ -0,0 +1,65 @@
+"""
+Adapters for the UUID type.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import uuid
+
+# Importing the uuid module is slow, so import it only on request.
+UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment]
+
+
+class UUIDDumper(Dumper):
+
+ oid = postgres.types["uuid"].oid
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.hex.encode()
+
+
+class UUIDBinaryDumper(UUIDDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.bytes
+
+
+class UUIDLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ global UUID
+ if UUID is None:
+ from uuid import UUID
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(data.decode())
+
+
+class UUIDBinaryLoader(UUIDLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(bytes=data)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("uuid.UUID", UUIDDumper)
+ adapters.register_dumper("uuid.UUID", UUIDBinaryDumper)
+ adapters.register_loader("uuid", UUIDLoader)
+ adapters.register_loader("uuid", UUIDBinaryLoader)
diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py
new file mode 100644
index 0000000..a98bc35
--- /dev/null
+++ b/psycopg/psycopg/version.py
@@ -0,0 +1,14 @@
+"""
+psycopg distribution version file.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+
+# STOP AND READ! if you change:
+__version__ = "3.1.7"
+# also change:
+# - `docs/news.rst` to declare this as the current version or an unreleased one
+# - `psycopg_c/psycopg_c/version.py` to the same version.
diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py
new file mode 100644
index 0000000..7abfc58
--- /dev/null
+++ b/psycopg/psycopg/waiting.py
@@ -0,0 +1,331 @@
+"""
+Code concerned with waiting in different contexts (blocking, async, etc).
+
+These functions are designed to consume the generators returned by the
+`generators` module function and to return their final value.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+import os
+import select
+import selectors
+from typing import Dict, Optional
+from asyncio import get_event_loop, wait_for, Event, TimeoutError
+from selectors import DefaultSelector
+
+from . import errors as e
+from .abc import RV, PQGen, PQGenConn, WaitFunc
+from ._enums import Wait as Wait, Ready as Ready # re-exported
+from ._cmodule import _psycopg
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+
+def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Consume `!gen`, scheduling `fileno` for completion when it is reported to
+ block. Once ready again send the ready state back to `!gen`.
+ """
+ try:
+ s = next(gen)
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = None
+ while not rlist:
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ # note: this line should require a cast, but mypy doesn't complain
+ ready: Ready = rlist[0][1]
+ assert s & ready
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a connection generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
+ """
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ if not rlist:
+ raise e.ConnectionTimeout("connection timeout expired")
+ ready: Ready = rlist[0][1] # type: ignore[assignment]
+ fileno, s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
+ """
+ Coroutine waiting for a generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but exposing an `asyncio` interface.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready |= state # type: ignore[assignment]
+ ev.set()
+
+ try:
+ s = next(gen)
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await ev.wait()
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Coroutine waiting for a connection generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready = state
+ ev.set()
+
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await wait_for(ev.wait(), timeout)
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ fileno, s = gen.send(ready)
+
+ except TimeoutError:
+ raise e.ConnectionTimeout("connection timeout expired")
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+# Specialised implementation of wait functions.
+
+
+def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using select where supported.
+ """
+ try:
+ s = next(gen)
+
+ empty = ()
+ fnlist = (fileno,)
+ while True:
+ rl, wl, xl = select.select(
+ fnlist if s & WAIT_R else empty,
+ fnlist if s & WAIT_W else empty,
+ fnlist,
+ timeout,
+ )
+ ready = 0
+ if rl:
+ ready = READY_R
+ if wl:
+ ready |= READY_W
+ if not ready:
+ continue
+ # assert s & ready
+ s = gen.send(ready) # type: ignore
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+poll_evmasks: Dict[Wait, int]
+
+if hasattr(selectors, "EpollSelector"):
+ poll_evmasks = {
+ WAIT_R: select.EPOLLONESHOT | select.EPOLLIN,
+ WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT,
+ WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT,
+ }
+else:
+ poll_evmasks = {}
+
+
+def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using epoll where supported.
+
+ Parameters are like for `wait()`. If it is detected that the best selector
+ strategy is `epoll` then this function will be used instead of `wait`.
+
+ See also: https://linux.die.net/man/2/epoll_ctl
+ """
+ try:
+ s = next(gen)
+
+ if timeout is None or timeout < 0:
+ timeout = 0
+ else:
+ timeout = int(timeout * 1000.0)
+
+ with select.epoll() as epoll:
+ evmask = poll_evmasks[s]
+ epoll.register(fileno, evmask)
+ while True:
+ fileevs = None
+ while not fileevs:
+ fileevs = epoll.poll(timeout)
+ ev = fileevs[0][1]
+ ready = 0
+ if ev & ~select.EPOLLOUT:
+ ready = READY_R
+ if ev & ~select.EPOLLIN:
+ ready |= READY_W
+ # assert s & ready
+ s = gen.send(ready)
+ evmask = poll_evmasks[s]
+ epoll.modify(fileno, evmask)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+if _psycopg:
+ wait_c = _psycopg.wait_c
+
+
+# Choose the best wait strategy for the platform.
+#
+# the selectors objects have a generic interface but come with some overhead,
+# so we also offer more finely tuned implementations.
+
+wait: WaitFunc
+
+# Allow the user to choose a specific function for testing
+if "PSYCOPG_WAIT_FUNC" in os.environ:
+ fname = os.environ["PSYCOPG_WAIT_FUNC"]
+ if not fname.startswith("wait_") or fname not in globals():
+ raise ImportError(
+ "PSYCOPG_WAIT_FUNC should be the name of an available wait function;"
+ f" got {fname!r}"
+ )
+ wait = globals()[fname]
+
+elif _psycopg:
+ wait = wait_c
+
+elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
+ # On Windows, SelectSelector should be the default.
+ wait = wait_select
+
+elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None):
+ # NOTE: select seems more performing than epoll. It is admittedly unlikely
+ # that a platform has epoll but not select, so maybe we could kill
+ # wait_epoll altogether(). More testing to do.
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll
+
+elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None):
+ # wait_select is faster than wait_selector, probably because of less overhead
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector
+
+else:
+ wait = wait_selector
diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml
new file mode 100644
index 0000000..21e410c
--- /dev/null
+++ b/psycopg/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37"]
+build-backend = "setuptools.build_meta"
diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg
new file mode 100644
index 0000000..fdcb612
--- /dev/null
+++ b/psycopg/setup.cfg
@@ -0,0 +1,47 @@
+[metadata]
+name = psycopg
+description = PostgreSQL database adapter for Python
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+packages = find:
+zip_safe = False
+install_requires =
+ backports.zoneinfo >= 0.2.0; python_version < "3.9"
+ typing-extensions >= 4.1
+ tzdata; sys_platform == "win32"
+
+[options.package_data]
+psycopg = py.typed
diff --git a/psycopg/setup.py b/psycopg/setup.py
new file mode 100644
index 0000000..90d4380
--- /dev/null
+++ b/psycopg/setup.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - pure Python package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+from setuptools import setup
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+# Only for release 3.1.7. Not building binary packages because Scaleway
+# has no runner available, but psycopg-binary 3.1.6 should work as well
+# as the only change is in rows.py.
+version = "3.1.7"
+ext_versions = ">= 3.1.6, <= 3.1.7"
+
+extras_require = {
+ # Install the C extension module (requires dev tools)
+ "c": [
+ f"psycopg-c {ext_versions}",
+ ],
+ # Install the stand-alone C extension module
+ "binary": [
+ f"psycopg-binary {ext_versions}",
+ ],
+ # Install the connection pool
+ "pool": [
+ "psycopg-pool",
+ ],
+ # Requirements to run the test suite
+ "test": [
+ "mypy >= 0.990",
+ "pproxy >= 2.7",
+ "pytest >= 6.2.5",
+ "pytest-asyncio >= 0.17",
+ "pytest-cov >= 3.0",
+ "pytest-randomly >= 3.10",
+ ],
+ # Requirements needed for development
+ "dev": [
+ "black >= 22.3.0",
+ "dnspython >= 2.1",
+ "flake8 >= 4.0",
+ "mypy >= 0.990",
+ "types-setuptools >= 57.4",
+ "wheel >= 0.37",
+ ],
+ # Requirements needed to build the documentation
+ "docs": [
+ "Sphinx >= 5.0",
+ "furo == 2022.6.21",
+ "sphinx-autobuild >= 2021.3.14",
+ "sphinx-autodoc-typehints >= 1.12",
+ ],
+}
+
+setup(
+ version=version,
+ extras_require=extras_require,
+)