Diffstat (limited to '')
-rw-r--r--  psycopg/.flake8  6
-rw-r--r--  psycopg/LICENSE.txt  165
-rw-r--r--  psycopg/README.rst  31
-rw-r--r--  psycopg/psycopg/__init__.py  110
-rw-r--r--  psycopg/psycopg/_adapters_map.py  289
-rw-r--r--  psycopg/psycopg/_cmodule.py  24
-rw-r--r--  psycopg/psycopg/_column.py  143
-rw-r--r--  psycopg/psycopg/_compat.py  72
-rw-r--r--  psycopg/psycopg/_dns.py  223
-rw-r--r--  psycopg/psycopg/_encodings.py  170
-rw-r--r--  psycopg/psycopg/_enums.py  79
-rw-r--r--  psycopg/psycopg/_pipeline.py  288
-rw-r--r--  psycopg/psycopg/_preparing.py  198
-rw-r--r--  psycopg/psycopg/_queries.py  375
-rw-r--r--  psycopg/psycopg/_struct.py  57
-rw-r--r--  psycopg/psycopg/_tpc.py  116
-rw-r--r--  psycopg/psycopg/_transform.py  350
-rw-r--r--  psycopg/psycopg/_typeinfo.py  461
-rw-r--r--  psycopg/psycopg/_tz.py  44
-rw-r--r--  psycopg/psycopg/_wrappers.py  137
-rw-r--r--  psycopg/psycopg/abc.py  266
-rw-r--r--  psycopg/psycopg/adapt.py  162
-rw-r--r--  psycopg/psycopg/client_cursor.py  95
-rw-r--r--  psycopg/psycopg/connection.py  1031
-rw-r--r--  psycopg/psycopg/connection_async.py  436
-rw-r--r--  psycopg/psycopg/conninfo.py  378
-rw-r--r--  psycopg/psycopg/copy.py  904
-rw-r--r--  psycopg/psycopg/crdb/__init__.py  19
-rw-r--r--  psycopg/psycopg/crdb/_types.py  163
-rw-r--r--  psycopg/psycopg/crdb/connection.py  186
-rw-r--r--  psycopg/psycopg/cursor.py  921
-rw-r--r--  psycopg/psycopg/cursor_async.py  250
-rw-r--r--  psycopg/psycopg/dbapi20.py  112
-rw-r--r--  psycopg/psycopg/errors.py  1535
-rw-r--r--  psycopg/psycopg/generators.py  320
-rw-r--r--  psycopg/psycopg/postgres.py  125
-rw-r--r--  psycopg/psycopg/pq/__init__.py  133
-rw-r--r--  psycopg/psycopg/pq/_debug.py  106
-rw-r--r--  psycopg/psycopg/pq/_enums.py  249
-rw-r--r--  psycopg/psycopg/pq/_pq_ctypes.py  804
-rw-r--r--  psycopg/psycopg/pq/_pq_ctypes.pyi  216
-rw-r--r--  psycopg/psycopg/pq/abc.py  385
-rw-r--r--  psycopg/psycopg/pq/misc.py  146
-rw-r--r--  psycopg/psycopg/pq/pq_ctypes.py  1086
-rw-r--r--  psycopg/psycopg/py.typed  0
-rw-r--r--  psycopg/psycopg/rows.py  256
-rw-r--r--  psycopg/psycopg/server_cursor.py  479
-rw-r--r--  psycopg/psycopg/sql.py  467
-rw-r--r--  psycopg/psycopg/transaction.py  290
-rw-r--r--  psycopg/psycopg/types/__init__.py  11
-rw-r--r--  psycopg/psycopg/types/array.py  464
-rw-r--r--  psycopg/psycopg/types/bool.py  51
-rw-r--r--  psycopg/psycopg/types/composite.py  290
-rw-r--r--  psycopg/psycopg/types/datetime.py  754
-rw-r--r--  psycopg/psycopg/types/enum.py  177
-rw-r--r--  psycopg/psycopg/types/hstore.py  131
-rw-r--r--  psycopg/psycopg/types/json.py  232
-rw-r--r--  psycopg/psycopg/types/multirange.py  514
-rw-r--r--  psycopg/psycopg/types/net.py  206
-rw-r--r--  psycopg/psycopg/types/none.py  25
-rw-r--r--  psycopg/psycopg/types/numeric.py  515
-rw-r--r--  psycopg/psycopg/types/range.py  700
-rw-r--r--  psycopg/psycopg/types/shapely.py  75
-rw-r--r--  psycopg/psycopg/types/string.py  239
-rw-r--r--  psycopg/psycopg/types/uuid.py  65
-rw-r--r--  psycopg/psycopg/version.py  14
-rw-r--r--  psycopg/psycopg/waiting.py  331
-rw-r--r--  psycopg/pyproject.toml  3
-rw-r--r--  psycopg/setup.cfg  47
-rw-r--r--  psycopg/setup.py  66
-rw-r--r--  psycopg_c/.flake8  3
-rw-r--r--  psycopg_c/LICENSE.txt  165
-rw-r--r--  psycopg_c/README-binary.rst  29
-rw-r--r--  psycopg_c/README.rst  33
-rw-r--r--  psycopg_c/psycopg_c/.gitignore  4
-rw-r--r--  psycopg_c/psycopg_c/__init__.py  14
-rw-r--r--  psycopg_c/psycopg_c/_psycopg.pyi  84
-rw-r--r--  psycopg_c/psycopg_c/_psycopg.pyx  48
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/__init__.pxd  9
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/adapt.pyx  171
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/copy.pyx  340
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/endian.pxd  155
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/generators.pyx  276
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/oids.pxd  92
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/transform.pyx  640
-rw-r--r--  psycopg_c/psycopg_c/_psycopg/waiting.pyx  197
-rw-r--r--  psycopg_c/psycopg_c/pq.pxd  78
-rw-r--r--  psycopg_c/psycopg_c/pq.pyx  38
-rw-r--r--  psycopg_c/psycopg_c/pq/__init__.pxd  9
-rw-r--r--  psycopg_c/psycopg_c/pq/conninfo.pyx  61
-rw-r--r--  psycopg_c/psycopg_c/pq/escaping.pyx  132
-rw-r--r--  psycopg_c/psycopg_c/pq/libpq.pxd  321
-rw-r--r--  psycopg_c/psycopg_c/pq/pgcancel.pyx  32
-rw-r--r--  psycopg_c/psycopg_c/pq/pgconn.pyx  733
-rw-r--r--  psycopg_c/psycopg_c/pq/pgresult.pyx  157
-rw-r--r--  psycopg_c/psycopg_c/pq/pqbuffer.pyx  111
-rw-r--r--  psycopg_c/psycopg_c/py.typed  0
-rw-r--r--  psycopg_c/psycopg_c/types/array.pyx  276
-rw-r--r--  psycopg_c/psycopg_c/types/bool.pyx  78
-rw-r--r--  psycopg_c/psycopg_c/types/datetime.pyx  1136
-rw-r--r--  psycopg_c/psycopg_c/types/numeric.pyx  715
-rw-r--r--  psycopg_c/psycopg_c/types/numutils.c  243
-rw-r--r--  psycopg_c/psycopg_c/types/string.pyx  315
-rw-r--r--  psycopg_c/psycopg_c/version.py  11
-rw-r--r--  psycopg_c/pyproject.toml  3
-rw-r--r--  psycopg_c/setup.cfg  57
-rw-r--r--  psycopg_c/setup.py  110
-rw-r--r--  psycopg_pool/.flake8  3
-rw-r--r--  psycopg_pool/LICENSE.txt  165
-rw-r--r--  psycopg_pool/README.rst  24
-rw-r--r--  psycopg_pool/psycopg_pool/__init__.py  22
-rw-r--r--  psycopg_pool/psycopg_pool/_compat.py  51
-rw-r--r--  psycopg_pool/psycopg_pool/base.py  230
-rw-r--r--  psycopg_pool/psycopg_pool/errors.py  25
-rw-r--r--  psycopg_pool/psycopg_pool/null_pool.py  159
-rw-r--r--  psycopg_pool/psycopg_pool/null_pool_async.py  122
-rw-r--r--  psycopg_pool/psycopg_pool/pool.py  839
-rw-r--r--  psycopg_pool/psycopg_pool/pool_async.py  784
-rw-r--r--  psycopg_pool/psycopg_pool/py.typed  0
-rw-r--r--  psycopg_pool/psycopg_pool/sched.py  177
-rw-r--r--  psycopg_pool/psycopg_pool/version.py  13
-rw-r--r--  psycopg_pool/setup.cfg  45
-rw-r--r--  psycopg_pool/setup.py  26
123 files changed, 29329 insertions, 0 deletions
diff --git a/psycopg/.flake8 b/psycopg/.flake8
new file mode 100644
index 0000000..67fb024
--- /dev/null
+++ b/psycopg/.flake8
@@ -0,0 +1,6 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
+per-file-ignores =
+ # Autogenerated section
+ psycopg/errors.py: E125, E128, E302
diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg/README.rst b/psycopg/README.rst
new file mode 100644
index 0000000..45eeac3
--- /dev/null
+++ b/psycopg/README.rst
@@ -0,0 +1,31 @@
+Psycopg 3: PostgreSQL database adapter for Python
+=================================================
+
+Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python.
+
+This distribution contains the pure Python package ``psycopg``.
+
+
+Installation
+------------
+
+In short, run the following::
+
+ pip install --upgrade pip # to upgrade pip
+ pip install "psycopg[binary,pool]" # to install package and dependencies
+
+If something goes wrong, and for more information about installation, please
+check out the `Installation documentation`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html#
+
+
+Hacking
+-------
+
+For development information check out `the project readme`__.
+
+.. __: https://github.com/psycopg/psycopg#readme
+
+
+Copyright (C) 2020 The Psycopg Team
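
The README above only covers installation. As a quick orientation for the rest of
the diff, a minimal usage sketch of the package being added (the connection string
is a placeholder, not part of this patch)::

    import psycopg

    # "dbname=test" is an illustrative DSN
    with psycopg.connect("dbname=test") as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT %s::int + %s::int", (1, 2))
            print(cur.fetchone())  # (3,)
        # the connection context commits on success and rolls back on error
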
diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
new file mode 100644
index 0000000..baadf30
--- /dev/null
+++ b/psycopg/psycopg/__init__.py
@@ -0,0 +1,110 @@
+"""
+psycopg -- PostgreSQL database adapter for Python
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from . import pq # noqa: F401 import early to stabilize side effects
+from . import types
+from . import postgres
+from ._tpc import Xid
+from .copy import Copy, AsyncCopy
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from .errors import Warning, Error, InterfaceError, DatabaseError
+from .errors import DataError, OperationalError, IntegrityError
+from .errors import InternalError, ProgrammingError, NotSupportedError
+from ._column import Column
+from .conninfo import ConnectionInfo
+from ._pipeline import Pipeline, AsyncPipeline
+from .connection import BaseConnection, Connection, Notify
+from .transaction import Rollback, Transaction, AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor, ServerCursor
+from .client_cursor import AsyncClientCursor, ClientCursor
+from .connection_async import AsyncConnection
+
+from . import dbapi20
+from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
+from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
+from .dbapi20 import Timestamp, TimestampFromTicks
+
+from .version import __version__ as __version__ # noqa: F401
+
+# Set the logger to a quiet default, can be enabled if needed
+logger = logging.getLogger("psycopg")
+if logger.level == logging.NOTSET:
+ logger.setLevel(logging.WARNING)
+
+# DBAPI compliance
+connect = Connection.connect
+apilevel = "2.0"
+threadsafety = 2
+paramstyle = "pyformat"
+
+# register default adapters for PostgreSQL
+adapters = postgres.adapters # exposed by the package
+postgres.register_default_adapters(adapters)
+
+# After the default ones, because these can deal with the bytea oid better
+dbapi20.register_dbapi20_adapters(adapters)
+
+# Must come after all the types have been registered
+types.array.register_all_arrays(adapters)
+
+# Note: defining the exported methods helps both Sphinx in documenting that
+# this is the canonical place to obtain them and should be used by MyPy too,
+# so that function signatures are consistent with the documentation.
+__all__ = [
+ "AsyncClientCursor",
+ "AsyncConnection",
+ "AsyncCopy",
+ "AsyncCursor",
+ "AsyncPipeline",
+ "AsyncServerCursor",
+ "AsyncTransaction",
+ "BaseConnection",
+ "ClientCursor",
+ "Column",
+ "Connection",
+ "ConnectionInfo",
+ "Copy",
+ "Cursor",
+ "IsolationLevel",
+ "Notify",
+ "Pipeline",
+ "Rollback",
+ "ServerCursor",
+ "Transaction",
+ "Xid",
+ # DBAPI exports
+ "connect",
+ "apilevel",
+ "threadsafety",
+ "paramstyle",
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ # DBAPI type constructors and singletons
+ "Binary",
+ "Date",
+ "DateFromTicks",
+ "Time",
+ "TimeFromTicks",
+ "Timestamp",
+ "TimestampFromTicks",
+ "BINARY",
+ "DATETIME",
+ "NUMBER",
+ "ROWID",
+ "STRING",
+]
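
``__init__.py`` above wires up the DBAPI 2.0 module-level interface: ``connect``,
``apilevel``/``paramstyle`` and the exception hierarchy. A sketch of how client
code typically relies on those exports (DSN and table are hypothetical)::

    import psycopg

    print(psycopg.apilevel, psycopg.paramstyle)  # "2.0" "pyformat"

    try:
        with psycopg.connect("dbname=test") as conn:
            conn.execute("INSERT INTO items (name) VALUES (%s)", ("spam",))
    except psycopg.OperationalError as exc:
        print("connection-level failure:", exc)   # server down, bad DSN, ...
    except psycopg.Error as exc:
        print("any other driver error:", exc)     # base of the DBAPI hierarchy
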
diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py
new file mode 100644
index 0000000..a3a6ef8
--- /dev/null
+++ b/psycopg/psycopg/_adapters_map.py
@@ -0,0 +1,289 @@
+"""
+Mapping from types/oids to Dumpers/Loaders
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+from typing import cast, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from .abc import Dumper, Loader
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+from ._typeinfo import TypesRegistry
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+RV = TypeVar("RV")
+
+
+class AdaptersMap:
+ r"""
+ Establish how types should be converted between Python and PostgreSQL in
+ an `~psycopg.abc.AdaptContext`.
+
+ `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
+ define how Python types are converted to PostgreSQL, and maps OIDs to
+ `~psycopg.adapt.Loader` classes to establish how query results are
+ converted to Python.
+
+ Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
+ types are converted in that context, exposed as the
+ `~psycopg.abc.AdaptContext.adapters` attribute: changing this map allows
+ you to customise adaptation in one context without affecting other contexts.
+
+ When a context is created from another context (for instance when a
+ `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
+ `!adapters` are used as a template for the child's `!adapters`, so that
+ every cursor created from the same connection uses the connection's type
+ configuration, but separate connections have independent mappings.
+
+ Once created, `!AdaptersMap` objects are independent. This means that
+ objects already created are not affected if a wider scope (e.g. the global
+ one) is changed.
+
+ A connection's adapters are initialised using a global `!AdaptersMap`
+ template, exposed as `psycopg.adapters`: changing that mapping customises
+ the type mapping for every connection created afterwards.
+
+ The object can start empty or copy from another object of the same class.
+ Copies are copy-on-write: the maps are only copied when they are updated.
+ This way extending e.g. the global map from a connection, or a connection
+ map from a cursor, is cheap: a copy is only made upon customisation.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ types: TypesRegistry
+
+ _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
+ _dumpers_by_oid: List[Dict[int, Type[Dumper]]]
+ _loaders: List[Dict[int, Type[Loader]]]
+
+ # Record if a dumper or loader has an optimised version.
+ _optimised: Dict[type, type] = {}
+
+ def __init__(
+ self,
+ template: Optional["AdaptersMap"] = None,
+ types: Optional[TypesRegistry] = None,
+ ):
+ if template:
+ self._dumpers = template._dumpers.copy()
+ self._own_dumpers = _dumpers_shared.copy()
+ template._own_dumpers = _dumpers_shared.copy()
+
+ self._dumpers_by_oid = template._dumpers_by_oid[:]
+ self._own_dumpers_by_oid = [False, False]
+ template._own_dumpers_by_oid = [False, False]
+
+ self._loaders = template._loaders[:]
+ self._own_loaders = [False, False]
+ template._own_loaders = [False, False]
+
+ self.types = TypesRegistry(template.types)
+
+ else:
+ self._dumpers = {fmt: {} for fmt in PyFormat}
+ self._own_dumpers = _dumpers_owned.copy()
+
+ self._dumpers_by_oid = [{}, {}]
+ self._own_dumpers_by_oid = [True, True]
+
+ self._loaders = [{}, {}]
+ self._own_loaders = [True, True]
+
+ self.types = types or TypesRegistry()
+
+ # implement the AdaptContext protocol too
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return None
+
+ def register_dumper(
+ self, cls: Union[type, str, None], dumper: Type[Dumper]
+ ) -> None:
+ """
+ Configure the context to use `!dumper` to convert objects of type `!cls`.
+
+ If two dumpers with different `~Dumper.format` are registered for the
+ same type, the last one registered will be chosen when the query
+ doesn't specify a format (i.e. when the value is used with a ``%s``
+ "`~PyFormat.AUTO`" placeholder).
+
+ :param cls: The type to manage.
+ :param dumper: The dumper to register for `!cls`.
+
+ If `!cls` is specified as a string it will be lazy-loaded, so that it
+ can be registered without importing it first. In this
+ case it should be the fully qualified name of the object (e.g.
+ ``"uuid.UUID"``).
+
+ If `!cls` is None, only use the dumper when looking up using
+ `get_dumper_by_oid()`, which happens when we know the Postgres type to
+ adapt to, but not the Python type that will be adapted (e.g. in COPY
+ after using `~psycopg.Copy.set_types()`).
+
+ """
+ if not (cls is None or isinstance(cls, (str, type))):
+ raise TypeError(
+ f"dumpers should be registered on classes, got {cls} instead"
+ )
+
+ if _psycopg:
+ dumper = self._get_optimised(dumper)
+
+ # Register the dumper both as its format and as auto
+ # so that the last dumper registered is used in auto (%s) format
+ if cls:
+ for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
+ if not self._own_dumpers[fmt]:
+ self._dumpers[fmt] = self._dumpers[fmt].copy()
+ self._own_dumpers[fmt] = True
+
+ self._dumpers[fmt][cls] = dumper
+
+ # Register the dumper by oid, if the oid of the dumper is fixed
+ if dumper.oid:
+ if not self._own_dumpers_by_oid[dumper.format]:
+ self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
+ dumper.format
+ ].copy()
+ self._own_dumpers_by_oid[dumper.format] = True
+
+ self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
+
+ def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
+ """
+ Configure the context to use `!loader` to convert data of oid `!oid`.
+
+ :param oid: The PostgreSQL OID or type name to manage.
+ :param loader: The loader to register for `!oid`.
+
+ If `oid` is specified as string, it refers to a type name, which is
+ looked up in the `types` registry.
+
+ """
+ if isinstance(oid, str):
+ oid = self.types[oid].oid
+ if not isinstance(oid, int):
+ raise TypeError(f"loaders should be registered on oid, got {oid} instead")
+
+ if _psycopg:
+ loader = self._get_optimised(loader)
+
+ fmt = loader.format
+ if not self._own_loaders[fmt]:
+ self._loaders[fmt] = self._loaders[fmt].copy()
+ self._own_loaders[fmt] = True
+
+ self._loaders[fmt][oid] = loader
+
+ def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given type and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param cls: The class to adapt.
+ :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
+ use the last one of the dumpers registered on `!cls`.
+ """
+ try:
+ dmap = self._dumpers[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ # Look for the right class, including looking at superclasses
+ for scls in cls.__mro__:
+ if scls in dmap:
+ return dmap[scls]
+
+ # If the adapter is not found, look for its name as a string
+ fqn = scls.__module__ + "." + scls.__qualname__
+ if fqn in dmap:
+ # Replace the class name with the class itself
+ d = dmap[scls] = dmap.pop(fqn)
+ return d
+
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'"
+ f" (format: {PyFormat(format).name})"
+ )
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
+ """
+ Return the dumper class for the given oid and format.
+
+ Raise `~psycopg.ProgrammingError` if a class is not available.
+
+ :param oid: The oid of the type to dump to.
+ :param format: The format to dump to.
+ """
+ try:
+ dmap = self._dumpers_by_oid[format]
+ except KeyError:
+ raise ValueError(f"bad dumper format: {format}")
+
+ try:
+ return dmap[oid]
+ except KeyError:
+ info = self.types.get(oid)
+ if info:
+ msg = (
+ f"cannot find a dumper for type {info.name} (oid {oid})"
+ f" format {pq.Format(format).name}"
+ )
+ else:
+ msg = (
+ f"cannot find a dumper for unknown type with oid {oid}"
+ f" format {pq.Format(format).name}"
+ )
+ raise e.ProgrammingError(msg)
+
+ def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
+ """
+ Return the loader class for the given oid and format.
+
+ Return `!None` if not found.
+
+ :param oid: The oid of the type to load.
+ :param format: The format to load from.
+ """
+ return self._loaders[format].get(oid)
+
+ @classmethod
+ def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
+ """Return the optimised version of a Dumper or Loader class.
+
+ Return the input class itself if there is no optimised version.
+ """
+ try:
+ return self._optimised[cls]
+ except KeyError:
+ pass
+
+ # Check if the class comes from psycopg.types and there is a class
+ # with the same name in psycopg_c._psycopg.
+ from psycopg import types
+
+ if cls.__module__.startswith(types.__name__):
+ new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
+ if new:
+ self._optimised[cls] = new
+ return new
+
+ self._optimised[cls] = cls
+ return cls
+
+
+# Micro-optimization: copying these objects is faster than creating new dicts
+_dumpers_owned = dict.fromkeys(PyFormat, True)
+_dumpers_shared = dict.fromkeys(PyFormat, False)
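
In practice ``AdaptersMap`` is reached through the ``adapters`` attribute of the
global module, of a connection, or of a cursor. A hedged sketch of registering a
dumper/loader pair, loosely based on the documented adaptation API (the
``DateDumper``/``DateLoader`` base classes and the infinity-date behaviour are
assumptions for illustration, not part of this file)::

    from datetime import date

    import psycopg
    from psycopg.types.datetime import DateDumper, DateLoader

    class InfDateDumper(DateDumper):
        def dump(self, obj):
            # map date.max to PostgreSQL "infinity" on the way out
            return b"infinity" if obj == date.max else super().dump(obj)

    class InfDateLoader(DateLoader):
        def load(self, data):
            # map "infinity" back to date.max on the way in
            return date.max if data == b"infinity" else super().load(data)

    conn = psycopg.connect("dbname=test")
    # registering on conn.adapters affects this connection only;
    # registering on psycopg.adapters would affect connections created afterwards
    conn.adapters.register_dumper(date, InfDateDumper)
    conn.adapters.register_loader("date", InfDateLoader)
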
diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py
new file mode 100644
index 0000000..288ef1b
--- /dev/null
+++ b/psycopg/psycopg/_cmodule.py
@@ -0,0 +1,24 @@
+"""
+Simplify access to the _psycopg module
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from typing import Optional
+
+from . import pq
+
+__version__: Optional[str] = None
+
+# Note: "c" must the first attempt so that mypy associates the variable the
+# right module interface. It will not result Optional, but hey.
+if pq.__impl__ == "c":
+ from psycopg_c import _psycopg as _psycopg
+ from psycopg_c import __version__ as __version__ # noqa: F401
+elif pq.__impl__ == "binary":
+ from psycopg_binary import _psycopg as _psycopg # type: ignore
+ from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
+elif pq.__impl__ == "python":
+ _psycopg = None # type: ignore
+else:
+ raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
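
``_cmodule.py`` only selects the optimised ``_psycopg`` extension matching the
libpq wrapper in use. A quick way to check what was picked at runtime (sketch)::

    import psycopg
    from psycopg._cmodule import _psycopg, __version__

    print(psycopg.pq.__impl__)  # "c", "binary" or "python"
    print(__version__)          # C package version, or None for pure Python
    print(_psycopg is None)     # True only for the pure Python implementation
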
diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py
new file mode 100644
index 0000000..9e4e735
--- /dev/null
+++ b/psycopg/psycopg/_column.py
@@ -0,0 +1,143 @@
+"""
+The Column object in Cursor.description
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
+from operator import attrgetter
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor
+
+
+class ColumnData(NamedTuple):
+ ftype: int
+ fmod: int
+ fsize: int
+
+
+class Column(Sequence[Any]):
+
+ __module__ = "psycopg"
+
+ def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
+ res = cursor.pgresult
+ assert res
+
+ fname = res.fname(index)
+ if fname:
+ self._name = fname.decode(cursor._encoding)
+ else:
+ # COPY_OUT results have columns but no name
+ self._name = f"column_{index + 1}"
+
+ self._data = ColumnData(
+ ftype=res.ftype(index),
+ fmod=res.fmod(index),
+ fsize=res.fsize(index),
+ )
+ self._type = cursor.adapters.types.get(self._data.ftype)
+
+ _attrs = tuple(
+ attrgetter(attr)
+ for attr in """
+ name type_code display_size internal_size precision scale null_ok
+ """.split()
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<Column {self.name!r},"
+ f" type: {self._type_display()} (oid: {self.type_code})>"
+ )
+
+ def __len__(self) -> int:
+ return 7
+
+ def _type_display(self) -> str:
+ parts = []
+ parts.append(self._type.name if self._type else str(self.type_code))
+
+ mod1 = self.precision
+ if mod1 is None:
+ mod1 = self.display_size
+ if mod1:
+ parts.append(f"({mod1}")
+ if self.scale:
+ parts.append(f", {self.scale}")
+ parts.append(")")
+
+ if self._type and self.type_code == self._type.array_oid:
+ parts.append("[]")
+
+ return "".join(parts)
+
+ def __getitem__(self, index: Any) -> Any:
+ if isinstance(index, slice):
+ return tuple(getter(self) for getter in self._attrs[index])
+ else:
+ return self._attrs[index](self)
+
+ @property
+ def name(self) -> str:
+ """The name of the column."""
+ return self._name
+
+ @property
+ def type_code(self) -> int:
+ """The numeric OID of the column."""
+ return self._data.ftype
+
+ @property
+ def display_size(self) -> Optional[int]:
+ """The field size, for :sql:`varchar(n)`, None otherwise."""
+ if not self._type:
+ return None
+
+ if self._type.name in ("varchar", "char"):
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod - 4
+
+ return None
+
+ @property
+ def internal_size(self) -> Optional[int]:
+ """The internal field size for fixed-size types, None otherwise."""
+ fsize = self._data.fsize
+ return fsize if fsize >= 0 else None
+
+ @property
+ def precision(self) -> Optional[int]:
+ """The number of digits for fixed precision types."""
+ if not self._type:
+ return None
+
+ dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
+ if self._type.name == "numeric":
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod >> 16
+
+ elif self._type.name in dttypes:
+ fmod = self._data.fmod
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def scale(self) -> Optional[int]:
+ """The number of digits after the decimal point if available."""
+ if self._type and self._type.name == "numeric":
+ fmod = self._data.fmod - 4
+ if fmod >= 0:
+ return fmod & 0xFFFF
+
+ return None
+
+ @property
+ def null_ok(self) -> Optional[bool]:
+ """Always `!None`"""
+ return None
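
``Column`` objects are what ``cursor.description`` is made of: they behave both as
7-item DBAPI sequences and as objects with named attributes. A short sketch (query
and DSN are illustrative)::

    import psycopg

    with psycopg.connect("dbname=test") as conn:
        cur = conn.execute("SELECT 42::numeric(10, 2) AS price")
        col = cur.description[0]
        print(col.name, col.type_code)   # "price" and the numeric OID
        print(col.precision, col.scale)  # 10, 2 (decoded from fmod as above)
        name, type_code = col[:2]        # sequence-style access also works
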
diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py
new file mode 100644
index 0000000..7dbae79
--- /dev/null
+++ b/psycopg/psycopg/_compat.py
@@ -0,0 +1,72 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
+
+# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
+# For this reason it must be imported directly from typing_extensions where used.
+# See https://github.com/microsoft/pyright/issues/4197
+from typing_extensions import TypeAlias
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+T = TypeVar("T")
+FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 8):
+ create_task = asyncio.create_task
+ from math import prod
+
+else:
+
+ def create_task(
+ coro: FutureT[T], name: Optional[str] = None
+ ) -> "asyncio.Future[T]":
+ return asyncio.create_task(coro)
+
+ from functools import reduce
+
+ def prod(seq: Sequence[int]) -> int:
+ return reduce(int.__mul__, seq, 1)
+
+
+if sys.version_info >= (3, 9):
+ from zoneinfo import ZoneInfo
+ from functools import cache
+ from collections import Counter, deque as Deque
+else:
+ from typing import Counter, Deque
+ from functools import lru_cache
+ from backports.zoneinfo import ZoneInfo
+
+ cache = lru_cache(maxsize=None)
+
+if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+else:
+ from typing_extensions import TypeGuard
+
+if sys.version_info >= (3, 11):
+ from typing import LiteralString
+else:
+ from typing_extensions import LiteralString
+
+__all__ = [
+ "Counter",
+ "Deque",
+ "LiteralString",
+ "Protocol",
+ "TypeGuard",
+ "ZoneInfo",
+ "cache",
+ "create_task",
+ "prod",
+]
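
The shims above are internal, but the pre-3.8 ``prod`` fallback is easy to
sanity-check: it must agree with ``math.prod``. A tiny self-contained sketch of
the same logic::

    import math
    from functools import reduce

    def prod(seq):
        # same fallback used above for Python < 3.8
        return reduce(int.__mul__, seq, 1)

    assert prod([2, 3, 4]) == math.prod([2, 3, 4]) == 24
    assert prod([]) == 1  # empty product
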
diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
new file mode 100644
index 0000000..1e146ba
--- /dev/null
+++ b/psycopg/psycopg/_dns.py
@@ -0,0 +1,223 @@
+# type: ignore # dnspython is currently optional and mypy fails if missing
+"""
+DNS query support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import os
+import re
+import warnings
+from random import randint
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
+from typing import TYPE_CHECKING
+from collections import defaultdict
+
+try:
+ from dns.resolver import Resolver, Cache
+ from dns.asyncresolver import Resolver as AsyncResolver
+ from dns.exception import DNSException
+except ImportError:
+ raise ImportError(
+ "the module psycopg._dns requires the package 'dnspython' installed"
+ )
+
+from . import errors as e
+from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_
+
+if TYPE_CHECKING:
+ from dns.rdtypes.IN.SRV import SRV
+
+resolver = Resolver()
+resolver.cache = Cache()
+
+async_resolver = AsyncResolver()
+async_resolver.cache = Cache()
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ .. deprecated:: 3.1
+ The use of this function is not necessary anymore, because
+ `psycopg.AsyncConnection.connect()` performs non-blocking name
+ resolution automatically.
+ """
+ warnings.warn(
+ "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
+ DeprecationWarning,
+ )
+ return await resolve_hostaddr_async_(params)
+
+
+def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply SRV DNS lookup as defined in :RFC:`2782`."""
+ return Rfc2782Resolver().resolve(params)
+
+
+async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Async equivalent of `resolve_srv()`."""
+ return await Rfc2782Resolver().resolve_async(params)
+
+
+class HostPort(NamedTuple):
+ host: str
+ port: str
+ totry: bool = False
+ target: Optional[str] = None
+
+
+class Rfc2782Resolver:
+ """Implement SRV RR Resolution as per RFC 2782
+
+ The class is organised to minimise code duplication between the sync and
+ the async paths.
+ """
+
+ re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
+
+ def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(self._resolve_srv(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(await self._resolve_srv_async(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
+ """
+ Return the list of hosts, and for each host whether an SRV lookup must be tried.
+
+ Return an empty list if no lookup is requested.
+ """
+ # If hostaddr is defined don't do any resolution.
+ if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
+ return []
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",")
+
+ if len(ports_in) == 1:
+ # If only one port is specified, it applies to all the hosts.
+ ports_in *= len(hosts_in)
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the libpq raises if it fails to connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+
+ out = []
+ srv_found = False
+ for host, port in zip(hosts_in, ports_in):
+ m = self.re_srv_rr.match(host)
+ if m or port.lower() == "srv":
+ srv_found = True
+ target = m.group("target") if m else None
+ hp = HostPort(host=host, port=port, totry=True, target=target)
+ else:
+ hp = HostPort(host=host, port=port)
+ out.append(hp)
+
+ return out if srv_found else []
+
+ def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = await async_resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ def _get_solved_entries(
+ self, hp: HostPort, entries: "Sequence[SRV]"
+ ) -> List[HostPort]:
+ if not entries:
+ # No SRV entry found. Delegate a plain QNAME=target lookup to the libpq
+ if hp.target and hp.port.lower() != "srv":
+ return [HostPort(host=hp.target, port=hp.port)]
+ else:
+ return []
+
+ # If there is precisely one SRV RR, and its Target is "." (the root
+ # domain), abort.
+ if len(entries) == 1 and str(entries[0].target) == ".":
+ return []
+
+ return [
+ HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
+ for entry in self.sort_rfc2782(entries)
+ ]
+
+ def _return_params(
+ self, params: Dict[str, Any], hps: List[HostPort]
+ ) -> Dict[str, Any]:
+ if not hps:
+ # Nothing found, we ended up with an empty list
+ raise e.OperationalError("no host found after SRV RR lookup")
+
+ out = params.copy()
+ out["host"] = ",".join(hp.host for hp in hps)
+ out["port"] = ",".join(str(hp.port) for hp in hps)
+ return out
+
+ def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
+ """
+ Implement the priority/weight ordering defined in RFC 2782.
+ """
+ # Divide the entries by priority:
+ priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
+ out: "List[SRV]" = []
+ for entry in ans:
+ priorities[entry.priority].append(entry)
+
+ for pri, entries in sorted(priorities.items()):
+ if len(entries) == 1:
+ out.append(entries[0])
+ continue
+
+ entries.sort(key=lambda ent: ent.weight)
+ total_weight = sum(ent.weight for ent in entries)
+ while entries:
+ r = randint(0, total_weight)
+ csum = 0
+ for i, ent in enumerate(entries):
+ csum += ent.weight
+ if csum >= r:
+ break
+ out.append(ent)
+ total_weight -= ent.weight
+ del entries[i]
+
+ return out
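
``resolve_srv()`` rewrites the ``host``/``port`` parameters after an RFC 2782 SRV
lookup. A hedged sketch of how it could be used (the SRV record name is made up,
and ``dnspython`` must be installed)::

    import psycopg
    from psycopg._dns import resolve_srv
    from psycopg.conninfo import make_conninfo

    params = {"host": "_postgresql._tcp.db.example.com", "port": "srv", "dbname": "app"}
    params = resolve_srv(params)  # host/port replaced by SRV targets in RFC 2782 order
    conn = psycopg.connect(make_conninfo(**params))
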
diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py
new file mode 100644
index 0000000..c584b26
--- /dev/null
+++ b/psycopg/psycopg/_encodings.py
@@ -0,0 +1,170 @@
+"""
+Mappings between PostgreSQL and Python encodings.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import string
+import codecs
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+from .pq._enums import ConnStatus
+from .errors import NotSupportedError
+from ._compat import cache
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+ from .connection import BaseConnection
+
+OK = ConnStatus.OK
+
+
+_py_codecs = {
+ "BIG5": "big5",
+ "EUC_CN": "gb2312",
+ "EUC_JIS_2004": "euc_jis_2004",
+ "EUC_JP": "euc_jp",
+ "EUC_KR": "euc_kr",
+ # "EUC_TW": not available in Python
+ "GB18030": "gb18030",
+ "GBK": "gbk",
+ "ISO_8859_5": "iso8859-5",
+ "ISO_8859_6": "iso8859-6",
+ "ISO_8859_7": "iso8859-7",
+ "ISO_8859_8": "iso8859-8",
+ "JOHAB": "johab",
+ "KOI8R": "koi8-r",
+ "KOI8U": "koi8-u",
+ "LATIN1": "iso8859-1",
+ "LATIN10": "iso8859-16",
+ "LATIN2": "iso8859-2",
+ "LATIN3": "iso8859-3",
+ "LATIN4": "iso8859-4",
+ "LATIN5": "iso8859-9",
+ "LATIN6": "iso8859-10",
+ "LATIN7": "iso8859-13",
+ "LATIN8": "iso8859-14",
+ "LATIN9": "iso8859-15",
+ # "MULE_INTERNAL": not available in Python
+ "SHIFT_JIS_2004": "shift_jis_2004",
+ "SJIS": "shift_jis",
+ # this actually means no encoding, see PostgreSQL docs
+ # it is special-cased by the text loader.
+ "SQL_ASCII": "ascii",
+ "UHC": "cp949",
+ "UTF8": "utf-8",
+ "WIN1250": "cp1250",
+ "WIN1251": "cp1251",
+ "WIN1252": "cp1252",
+ "WIN1253": "cp1253",
+ "WIN1254": "cp1254",
+ "WIN1255": "cp1255",
+ "WIN1256": "cp1256",
+ "WIN1257": "cp1257",
+ "WIN1258": "cp1258",
+ "WIN866": "cp866",
+ "WIN874": "cp874",
+}
+
+py_codecs: Dict[bytes, str] = {}
+py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
+
+# Add an alias without underscore, for lenient lookups
+py_codecs.update(
+ (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
+)
+
+pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
+
+
+def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
+ """
+ Return the Python encoding name of a psycopg connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if not conn or conn.closed:
+ return "utf-8"
+
+ pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def pgconn_encoding(pgconn: "PGconn") -> str:
+ """
+ Return the Python encoding name of a libpq connection.
+
+ Default to utf8 if the connection has no encoding info.
+ """
+ if pgconn.status != OK:
+ return "utf-8"
+
+ pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
+ return pg2pyenc(pgenc)
+
+
+def conninfo_encoding(conninfo: str) -> str:
+ """
+ Return the Python encoding name passed in a conninfo string. Default to utf8.
+
+ Because the input is likely to come from the user and not normalised by the
+ server, be somewhat lenient (case-insensitive lookup, ignore noise chars).
+ """
+ from .conninfo import conninfo_to_dict
+
+ params = conninfo_to_dict(conninfo)
+ pgenc = params.get("client_encoding")
+ if pgenc:
+ try:
+ return pg2pyenc(pgenc.encode())
+ except NotSupportedError:
+ pass
+
+ return "utf-8"
+
+
+@cache
+def py2pgenc(name: str) -> bytes:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise LookupError if the Python encoding is unknown.
+ """
+ return pg_codecs[codecs.lookup(name).name]
+
+
+@cache
+def pg2pyenc(name: bytes) -> str:
+ """Convert a Python encoding name to PostgreSQL encoding name.
+
+ Raise NotSupportedError if the PostgreSQL encoding is not supported by
+ Python.
+ """
+ try:
+ return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
+ except KeyError:
+ sname = name.decode("utf8", "replace")
+ raise NotSupportedError(f"codec not available in Python: {sname!r}")
+
+
+def _as_python_identifier(s: str, prefix: str = "f") -> str:
+ """
+ Reduce a string to a valid Python identifier.
+
+ Replace all non-valid chars with '_' and prefix the value with `!prefix` if
+ the first letter is an '_'.
+ """
+ if not s.isidentifier():
+ if s[0] in "1234567890":
+ s = prefix + s
+ if not s.isidentifier():
+ s = _re_clean.sub("_", s)
+ # namedtuple fields cannot start with underscore. So...
+ if s[0] == "_":
+ s = prefix + s
+ return s
+
+
+_re_clean = re.compile(
+ f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
+)
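
These helpers convert encoding names in both directions and normalise user input.
A few illustrative calls, following the tables and code above::

    from psycopg._encodings import py2pgenc, pg2pyenc, _as_python_identifier

    print(py2pgenc("utf-8"))     # b'UTF8'
    print(py2pgenc("latin9"))    # b'LATIN9' (Python codec aliases resolved first)
    print(pg2pyenc(b"WIN1252"))  # 'cp1252'
    print(pg2pyenc(b"latin-1"))  # 'iso8859-1' (case and '-'/'_' are ignored)
    print(_as_python_identifier("2nd col"))  # 'f2nd_col'
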
diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py
new file mode 100644
index 0000000..a7cb78d
--- /dev/null
+++ b/psycopg/psycopg/_enums.py
@@ -0,0 +1,79 @@
+"""
+Enum values for psycopg
+
+These values are defined by us and are not necessarily dependent on
+libpq-defined enums.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import Enum, IntEnum
+from selectors import EVENT_READ, EVENT_WRITE
+
+from . import pq
+
+
+class Wait(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class Ready(IntEnum):
+ R = EVENT_READ
+ W = EVENT_WRITE
+ RW = EVENT_READ | EVENT_WRITE
+
+
+class PyFormat(str, Enum):
+ """
+ Enum representing the format wanted for a query argument.
+
+ The value `AUTO` allows psycopg to choose the best format for a certain
+ parameter.
+ """
+
+ __module__ = "psycopg.adapt"
+
+ AUTO = "s"
+ """Automatically chosen (``%s`` placeholder)."""
+ TEXT = "t"
+ """Text parameter (``%t`` placeholder)."""
+ BINARY = "b"
+ """Binary parameter (``%b`` placeholder)."""
+
+ @classmethod
+ def from_pq(cls, fmt: pq.Format) -> "PyFormat":
+ return _pg2py[fmt]
+
+ @classmethod
+ def as_pq(cls, fmt: "PyFormat") -> pq.Format:
+ return _py2pg[fmt]
+
+
+class IsolationLevel(IntEnum):
+ """
+ Enum representing the isolation level for a transaction.
+ """
+
+ __module__ = "psycopg"
+
+ READ_UNCOMMITTED = 1
+ """:sql:`READ UNCOMMITTED` isolation level."""
+ READ_COMMITTED = 2
+ """:sql:`READ COMMITTED` isolation level."""
+ REPEATABLE_READ = 3
+ """:sql:`REPEATABLE READ` isolation level."""
+ SERIALIZABLE = 4
+ """:sql:`SERIALIZABLE` isolation level."""
+
+
+_py2pg = {
+ PyFormat.TEXT: pq.Format.TEXT,
+ PyFormat.BINARY: pq.Format.BINARY,
+}
+
+_pg2py = {
+ pq.Format.TEXT: PyFormat.TEXT,
+ pq.Format.BINARY: PyFormat.BINARY,
+}
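
``PyFormat`` backs the ``%s``/``%t``/``%b`` placeholders, while ``IsolationLevel``
is the value accepted by the connection's isolation setting. A short sketch (DSN
and values are illustrative; the isolation level must be set while no transaction
is in progress)::

    import psycopg
    from psycopg import IsolationLevel

    with psycopg.connect("dbname=test") as conn:
        conn.isolation_level = IsolationLevel.SERIALIZABLE

        with conn.cursor() as cur:
            # %s = AUTO (psycopg picks), %t = text, %b = binary parameter format
            cur.execute("SELECT %s, %t, %b", (1, "two", b"three"))
            print(cur.fetchone())
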
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
new file mode 100644
index 0000000..c818d86
--- /dev/null
+++ b/psycopg/psycopg/_pipeline.py
@@ -0,0 +1,288 @@
+"""
+commands pipeline management
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+from types import TracebackType
+from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from .abc import PipelineCommand, PQGen
+from ._compat import Deque
+from ._encodings import pgconn_encoding
+from ._preparing import Key, Prepare
+from .generators import pipeline_communicate, fetch_many, send
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+ from .cursor import BaseCursor
+ from .connection import BaseConnection, Connection
+ from .connection_async import AsyncConnection
+
+
+PendingResult: TypeAlias = Union[
+ None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
+]
+
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+BAD = pq.ConnStatus.BAD
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+logger = logging.getLogger("psycopg")
+
+
+class BasePipeline:
+
+ command_queue: Deque[PipelineCommand]
+ result_queue: Deque[PendingResult]
+ _is_supported: Optional[bool] = None
+
+ def __init__(self, conn: "BaseConnection[Any]") -> None:
+ self._conn = conn
+ self.pgconn = conn.pgconn
+ self.command_queue = Deque[PipelineCommand]()
+ self.result_queue = Deque[PendingResult]()
+ self.level = 0
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._conn.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def status(self) -> pq.PipelineStatus:
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ @classmethod
+ def is_supported(cls) -> bool:
+ """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
+ if BasePipeline._is_supported is None:
+ BasePipeline._is_supported = not cls._not_supported_reason()
+ return BasePipeline._is_supported
+
+ @classmethod
+ def _not_supported_reason(cls) -> str:
+ """Return the reason why the pipeline mode is not supported.
+
+ Return an empty string if pipeline mode is supported.
+ """
+ # Support only depends on the libpq functions available in the pq
+ # wrapper, not on the database version.
+ if pq.version() < 140000:
+ return (
+ f"libpq too old {pq.version()};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ if pq.__build_version__ < 140000:
+ return (
+ f"libpq too old: module built for {pq.__build_version__};"
+ " v14 or greater required for pipeline mode"
+ )
+
+ return ""
+
+ def _enter_gen(self) -> PQGen[None]:
+ if not self.is_supported():
+ raise e.NotSupportedError(
+ f"pipeline mode not supported: {self._not_supported_reason()}"
+ )
+ if self.level == 0:
+ self.pgconn.enter_pipeline_mode()
+ elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
+ # Nested pipeline case.
+ # Transaction might be ACTIVE when the pipeline uses an "implicit
+ # transaction", typically in autocommit mode. But when entering a
+ # Psycopg transaction(), we expect the IDLE state. By sync()-ing,
+ # we make sure all previous commands are completed and the
+ # transaction gets back to IDLE.
+ yield from self._sync_gen()
+ self.level += 1
+
+ def _exit(self, exc: Optional[BaseException]) -> None:
+ self.level -= 1
+ if self.level == 0 and self.pgconn.status != BAD:
+ try:
+ self.pgconn.exit_pipeline_mode()
+ except e.OperationalError as exc2:
+ # Notice that this error might be pretty irrecoverable. It
+ # happens on COPY, for instance: even if sync succeeds, exiting
+ # fails with "cannot exit pipeline mode with uncollected results"
+ if exc:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+
+ def _sync_gen(self) -> PQGen[None]:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ yield from self._fetch_gen(flush=False)
+
+ def _exit_gen(self) -> PQGen[None]:
+ """
+ Exit current pipeline by sending a Sync and fetch back all remaining results.
+ """
+ try:
+ self._enqueue_sync()
+ yield from self._communicate_gen()
+ finally:
+ # No need to force flush since we emitted a sync just before.
+ yield from self._fetch_gen(flush=False)
+
+ def _communicate_gen(self) -> PQGen[None]:
+ """Communicate with pipeline to send commands and possibly fetch
+ results, which are then processed.
+ """
+ fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
+ to_process = [(self.result_queue.popleft(), results) for results in fetched]
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
+ """Fetch available results from the connection and process them with
+ pipeline queued items.
+
+ If 'flush' is True, a PQsendFlushRequest() is issued in order to make
+ sure results can be fetched. Otherwise, the caller may emit a
+ PQpipelineSync() call to ensure the output buffer gets flushed before
+ fetching.
+ """
+ if not self.result_queue:
+ return
+
+ if flush:
+ self.pgconn.send_flush_request()
+ yield from send(self.pgconn)
+
+ to_process = []
+ while self.result_queue:
+ results = yield from fetch_many(self.pgconn)
+ if not results:
+ # No more results to fetch, but there may still be pending
+ # commands.
+ break
+ queued = self.result_queue.popleft()
+ to_process.append((queued, results))
+
+ for queued, results in to_process:
+ self._process_results(queued, results)
+
+ def _process_results(
+ self, queued: PendingResult, results: List["PGresult"]
+ ) -> None:
+ """Process a results set fetched from the current pipeline.
+
+ This matches 'results' with its respective element in the pipeline
+ queue. For commands (None value in the pipeline queue), results are
+ checked directly. For prepare statement creation requests, update the
+ cache. Otherwise, results are attached to their respective cursor.
+ """
+ if queued is None:
+ (result,) = results
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ elif result.status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ cursor, prepinfo = queued
+ cursor._set_results_from_pipeline(results)
+ if prepinfo:
+ key, prep, name = prepinfo
+ # Update the prepare state of the query.
+ cursor._conn._prepared.validate(key, prep, name, results)
+
+ def _enqueue_sync(self) -> None:
+ """Enqueue a PQpipelineSync() command."""
+ self.command_queue.append(self.pgconn.pipeline_sync)
+ self.result_queue.append(None)
+
+
+class Pipeline(BasePipeline):
+ """Handler for connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "Connection[Any]"
+ _Self = TypeVar("_Self", bound="Pipeline")
+
+ def __init__(self, conn: "Connection[Any]") -> None:
+ super().__init__(conn)
+
+ def sync(self) -> None:
+ """Sync the pipeline, send any pending command and receive and process
+ all available results.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
+
+
+class AsyncPipeline(BasePipeline):
+ """Handler for async connection in pipeline mode."""
+
+ __module__ = "psycopg"
+ _conn: "AsyncConnection[Any]"
+ _Self = TypeVar("_Self", bound="AsyncPipeline")
+
+ def __init__(self, conn: "AsyncConnection[Any]") -> None:
+ super().__init__(conn)
+
+ async def sync(self) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._sync_gen())
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._exit_gen())
+ except Exception as exc2:
+ # Don't clobber an exception raised in the block with this one
+ if exc_val:
+ logger.warning("error ignored terminating %r: %s", self, exc2)
+ else:
+ raise exc2.with_traceback(None)
+ finally:
+ self._exit(exc_val)
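
From user code the pipeline is entered through ``Connection.pipeline()``: queries
are queued, and a Sync is emitted when the block exits so that all results are
fetched and matched back. A minimal sketch (table and DSN are hypothetical)::

    import psycopg

    with psycopg.connect("dbname=test", autocommit=True) as conn:
        if not psycopg.Pipeline.is_supported():
            raise SystemExit("libpq >= 14 is required for pipeline mode")

        with conn.pipeline() as p:
            with conn.cursor() as cur:
                for i in range(100):
                    cur.execute("INSERT INTO events (n) VALUES (%s)", (i,))
                p.sync()  # optional: flush pending commands and process results now
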
diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py
new file mode 100644
index 0000000..f60c0cb
--- /dev/null
+++ b/psycopg/psycopg/_preparing.py
@@ -0,0 +1,198 @@
+"""
+Support for prepared statements
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, auto
+from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
+from collections import OrderedDict
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._compat import Deque
+from ._queries import PostgresQuery
+
+if TYPE_CHECKING:
+ from .pq.abc import PGresult
+
+Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+
+class Prepare(IntEnum):
+ NO = auto()
+ YES = auto()
+ SHOULD = auto()
+
+
+class PrepareManager:
+ # Number of times a query is executed before it is prepared.
+ prepare_threshold: Optional[int] = 5
+
+ # Maximum number of prepared statements on the connection.
+ prepared_max: int = 100
+
+ def __init__(self) -> None:
+ # Map (query, types) to the number of times the query was seen.
+ self._counts: OrderedDict[Key, int] = OrderedDict()
+
+ # Map (query, types) to the name of the statement if prepared.
+ self._names: OrderedDict[Key, bytes] = OrderedDict()
+
+ # Counter to generate prepared statements names
+ self._prepared_idx = 0
+
+ self._maint_commands = Deque[bytes]()
+
+ @staticmethod
+ def key(query: PostgresQuery) -> Key:
+ return (query.query, query.types)
+
+ def get(
+ self, query: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ """
+ Check if a query is already prepared and report whether it should be prepared.
+ """
+ if prepare is False or self.prepare_threshold is None:
+ # The user doesn't want this query to be prepared
+ return Prepare.NO, b""
+
+ key = self.key(query)
+ name = self._names.get(key)
+ if name:
+ # The query was already prepared in this session
+ return Prepare.YES, name
+
+ count = self._counts.get(key, 0)
+ if count >= self.prepare_threshold or prepare:
+ # The query has been executed enough times and needs to be prepared
+ name = f"_pg3_{self._prepared_idx}".encode()
+ self._prepared_idx += 1
+ return Prepare.SHOULD, name
+ else:
+ # The query is not to be prepared yet
+ return Prepare.NO, b""
+
+ def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
+ """Check if we need to discard our entire state: it should happen on
+ rollback or on dropping objects, because the same object may get
+ recreated and postgres would fail internal lookups.
+ """
+ if self._names or prep == Prepare.SHOULD:
+ for result in results:
+ if result.status != COMMAND_OK:
+ continue
+ cmdstat = result.command_status
+ if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
+ return self.clear()
+ return False
+
+ @staticmethod
+ def _check_results(results: Sequence["PGresult"]) -> bool:
+ """Return False if 'results' are invalid for prepared statement cache."""
+ if len(results) != 1:
+ # We cannot prepare a query made of multiple statements
+ return False
+
+ status = results[0].status
+ if COMMAND_OK != status != TUPLES_OK:
+ # We don't prepare failed queries or other weird results
+ return False
+
+ return True
+
+ def _rotate(self) -> None:
+ """Evict an old value from the cache.
+
+ If it was prepared, deallocate it. Do it only once: if the cache was
+ resized, deallocate gradually.
+ """
+ if len(self._counts) > self.prepared_max:
+ self._counts.popitem(last=False)
+
+ if len(self._names) > self.prepared_max:
+ name = self._names.popitem(last=False)[1]
+ self._maint_commands.append(b"DEALLOCATE " + name)
+
+ def maybe_add_to_cache(
+ self, query: PostgresQuery, prep: Prepare, name: bytes
+ ) -> Optional[Key]:
+ """Handle 'query' for possible addition to the cache.
+
+ If a new entry has been added, return its key. Return None otherwise
+ (meaning the query is already in cache or cache is not enabled).
+
+ Note: This method is only called in pipeline mode.
+ """
+ # don't do anything if prepared statements are disabled
+ if self.prepare_threshold is None:
+ return None
+
+ key = self.key(query)
+ if key in self._counts:
+ if prep is Prepare.SHOULD:
+ del self._counts[key]
+ self._names[key] = name
+ else:
+ self._counts[key] += 1
+ self._counts.move_to_end(key)
+ return None
+
+ elif key in self._names:
+ self._names.move_to_end(key)
+ return None
+
+ else:
+ if prep is Prepare.SHOULD:
+ self._names[key] = name
+ else:
+ self._counts[key] = 1
+ return key
+
+ def validate(
+ self,
+ key: Key,
+ prep: Prepare,
+ name: bytes,
+ results: Sequence["PGresult"],
+ ) -> None:
+ """Validate cached entry with 'key' by checking query 'results'.
+
+ Possibly enqueue a maintenance command to be run on the database side.
+
+ Note: this method is only called in pipeline mode.
+ """
+ if self._should_discard(prep, results):
+ return
+
+ if not self._check_results(results):
+ self._names.pop(key, None)
+ self._counts.pop(key, None)
+ else:
+ self._rotate()
+
+ def clear(self) -> bool:
+ """Clear the cache of the maintenance commands.
+
+ Clear the internal state and prepare a command to clear the state of
+ the server.
+ """
+ self._counts.clear()
+ if self._names:
+ self._names.clear()
+ self._maint_commands.clear()
+ self._maint_commands.append(b"DEALLOCATE ALL")
+ return True
+ else:
+ return False
+
+ def get_maintenance_commands(self) -> Iterator[bytes]:
+ """
+ Iterate over the commands needed to align the server state to our state
+ """
+ while self._maint_commands:
+ yield self._maint_commands.popleft()
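A rough standalone sketch of the threshold behaviour of PrepareManager; the cursor machinery normally drives these calls, so the stub query object and the use of the private module path are assumptions of this illustration:

    from psycopg._preparing import PrepareManager, Prepare

    class StubQuery:
        # minimal stand-in for PostgresQuery: only the attributes used as cache key
        query = b"select * from t where id = $1"
        types = (23,)

    mgr = PrepareManager()
    q = StubQuery()
    for i in range(6):
        prep, name = mgr.get(q)
        mgr.maybe_add_to_cache(q, prep, name)
    # executions 1-5 yield Prepare.NO; the 6th yields Prepare.SHOULD with name b"_pg3_0"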
diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py
new file mode 100644
index 0000000..2a7554c
--- /dev/null
+++ b/psycopg/psycopg/_queries.py
@@ -0,0 +1,375 @@
+"""
+Utility module to manipulate queries
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
+from typing import Sequence, Tuple, Union, TYPE_CHECKING
+from functools import lru_cache
+
+from . import pq
+from . import errors as e
+from .sql import Composable
+from .abc import Buffer, Query, Params
+from ._enums import PyFormat
+from ._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+
+
+class QueryPart(NamedTuple):
+ pre: bytes
+ item: Union[int, str]
+ format: PyFormat
+
+
+class PostgresQuery:
+ """
+ Helper to convert a Python query and parameters into Postgres format.
+ """
+
+ __slots__ = """
+ query params types formats
+ _tx _want_formats _parts _encoding _order
+ """.split()
+
+ def __init__(self, transformer: "Transformer"):
+ self._tx = transformer
+
+ self.params: Optional[Sequence[Optional[Buffer]]] = None
+ # these are tuples so they can be used as keys e.g. in prepared stmts
+ self.types: Tuple[int, ...] = ()
+
+ # The formats requested by the user and the ones really passed to Postgres
+ self._want_formats: Optional[List[PyFormat]] = None
+ self.formats: Optional[Sequence[pq.Format]] = None
+
+ self._encoding = conn_encoding(transformer.connection)
+ self._parts: List[QueryPart]
+ self.query = b""
+ self._order: Optional[List[str]] = None
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained by accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (
+ self.query,
+ self._want_formats,
+ self._order,
+ self._parts,
+ ) = _query2pg(bquery, self._encoding)
+ else:
+ self.query = bquery
+ self._want_formats = self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ assert self._want_formats is not None
+ self.params = self._tx.dump_sequence(params, self._want_formats)
+ self.types = self._tx.types or ()
+ self.formats = self._tx.formats
+ else:
+ self.params = None
+ self.types = ()
+ self.formats = None
+
+
+class PostgresClientQuery(PostgresQuery):
+ """
+ PostgresQuery subclass merging query and arguments client-side.
+ """
+
+ __slots__ = ("template",)
+
+ def convert(self, query: Query, vars: Optional[Params]) -> None:
+ """
+ Set up the query and parameters to convert.
+
+ The results of this function can be obtained by accessing the object
+ attributes (`query`, `params`, `types`, `formats`).
+ """
+ if isinstance(query, str):
+ bquery = query.encode(self._encoding)
+ elif isinstance(query, Composable):
+ bquery = query.as_bytes(self._tx)
+ else:
+ bquery = query
+
+ if vars is not None:
+ (self.template, self._order, self._parts) = _query2pg_client(
+ bquery, self._encoding
+ )
+ else:
+ self.query = bquery
+ self._order = None
+
+ self.dump(vars)
+
+ def dump(self, vars: Optional[Params]) -> None:
+ """
+ Process a new set of variables on the query processed by `convert()`.
+
+ This method updates `params` and `types`.
+ """
+ if vars is not None:
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
+ self.params = tuple(
+ self._tx.as_literal(p) if p is not None else b"NULL" for p in params
+ )
+ self.query = self.template % self.params
+ else:
+ self.params = None
+
+
+@lru_cache()
+def _query2pg(
+ query: bytes, encoding: str
+) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into something Postgres understands.
+
+ - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+ format (``$1``, ``$2``)
+ - placeholders can be %s, %t, or %b (auto, text or binary)
+ - return ``query`` (bytes), ``formats`` (list of formats), ``order``
+ (sequence of names used in the query, in the position they appear),
+ and ``parts`` (splits of queries and placeholders).
+ """
+ parts = _split_query(query, encoding)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+ formats = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"$%d" % (part.item + 1))
+ formats.append(part.format)
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"$%d" % (len(seen) + 1)
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ formats.append(part.format)
+ else:
+ if seen[part.item][1] != part.format:
+ raise e.ProgrammingError(
+ f"placeholder '{part.item}' cannot have different formats"
+ )
+ chunks.append(seen[part.item][0])
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), formats, order, parts
+
+
+@lru_cache()
+def _query2pg_client(
+ query: bytes, encoding: str
+) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
+ """
+ Convert Python query and params into a template to perform client-side binding
+ """
+ parts = _split_query(query, encoding, collapse_double_percent=False)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+
+ if isinstance(parts[0].item, int):
+ for part in parts[:-1]:
+ assert isinstance(part.item, int)
+ chunks.append(part.pre)
+ chunks.append(b"%s")
+
+ elif isinstance(parts[0].item, str):
+ seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+ order = []
+ for part in parts[:-1]:
+ assert isinstance(part.item, str)
+ chunks.append(part.pre)
+ if part.item not in seen:
+ ph = b"%s"
+ seen[part.item] = (ph, part.format)
+ order.append(part.item)
+ chunks.append(ph)
+ else:
+ chunks.append(seen[part.item][0])
+ order.append(part.item)
+
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), order, parts
+
+
+def _validate_and_reorder_params(
+ parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+) -> Sequence[Any]:
+ """
+ Verify the compatibility between a query and a set of params.
+ """
+ # Try concrete types, then abstract types
+ t = type(vars)
+ if t is list or t is tuple:
+ sequence = True
+ elif t is dict:
+ sequence = False
+ elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+ sequence = True
+ elif isinstance(vars, Mapping):
+ sequence = False
+ else:
+ raise TypeError(
+ "query parameters should be a sequence or a mapping,"
+ f" got {type(vars).__name__}"
+ )
+
+ if sequence:
+ if len(vars) != len(parts) - 1:
+ raise e.ProgrammingError(
+ f"the query has {len(parts) - 1} placeholders but"
+ f" {len(vars)} parameters were passed"
+ )
+ if vars and not isinstance(parts[0].item, int):
+ raise TypeError("named placeholders require a mapping of parameters")
+ return vars # type: ignore[return-value]
+
+ else:
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+ raise TypeError(
+ "positional placeholders (%s) require a sequence of parameters"
+ )
+ try:
+ return [vars[item] for item in order or ()] # type: ignore[call-overload]
+ except KeyError:
+ raise e.ProgrammingError(
+ "query parameter missing:"
+ f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+ )
+
+
+_re_placeholder = re.compile(
+ rb"""(?x)
+ % # a literal %
+ (?:
+ (?:
+ \( ([^)]+) \) # or a name in (braces)
+ . # followed by a format
+ )
+ |
+ (?:.) # or any char, really
+ )
+ """
+)
+
+
+def _split_query(
+ query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
+) -> List[QueryPart]:
+ parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
+ cur = 0
+
+ # pairs (fragment, match), with the last match None
+ m = None
+ for m in _re_placeholder.finditer(query):
+ pre = query[cur : m.span(0)[0]]
+ parts.append((pre, m))
+ cur = m.span(0)[1]
+ if m:
+ parts.append((query[cur:], None))
+ else:
+ parts.append((query, None))
+
+ rv = []
+
+ # drop the "%%", validate
+ i = 0
+ phtype = None
+ while i < len(parts):
+ pre, m = parts[i]
+ if m is None:
+ # last part
+ rv.append(QueryPart(pre, 0, PyFormat.AUTO))
+ break
+
+ ph = m.group(0)
+ if ph == b"%%":
+ # unescape '%%' to '%' if necessary, then merge the parts
+ if collapse_double_percent:
+ ph = b"%"
+ pre1, m1 = parts[i + 1]
+ parts[i + 1] = (pre + ph + pre1, m1)
+ del parts[i]
+ continue
+
+ if ph == b"%(":
+ raise e.ProgrammingError(
+ "incomplete placeholder:"
+ f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
+ )
+ elif ph == b"% ":
+ # explicit message for a typical error
+ raise e.ProgrammingError(
+ "incomplete placeholder: '%'; if you want to use '%' as an"
+ " operator you can double it up, i.e. use '%%'"
+ )
+ elif ph[-1:] not in b"sbt":
+ raise e.ProgrammingError(
+ "only '%s', '%b', '%t' are allowed as placeholders, got"
+ f" '{m.group(0).decode(encoding)}'"
+ )
+
+ # Index or name
+ item: Union[int, str]
+ item = m.group(1).decode(encoding) if m.group(1) else i
+
+ if not phtype:
+ phtype = type(item)
+ elif phtype is not type(item):
+ raise e.ProgrammingError(
+ "positional and named placeholders cannot be mixed"
+ )
+
+ format = _ph_to_fmt[ph[-1:]]
+ rv.append(QueryPart(pre, item, format))
+ i += 1
+
+ return rv
+
+
+_ph_to_fmt = {
+ b"s": PyFormat.AUTO,
+ b"t": PyFormat.TEXT,
+ b"b": PyFormat.BINARY,
+}
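A quick illustration of the conversion performed by _query2pg() (a standalone sketch; the private module path is an assumption of this example):

    from psycopg._queries import _query2pg

    q, formats, order, parts = _query2pg(b"select %(id)s, %(id)s, 10 %% 3", "utf-8")
    # q       -> b"select $1, $1, 10 % 3"   (named placeholders become numbered ones)
    # formats -> [PyFormat.AUTO]            (one entry per distinct placeholder)
    # order   -> ["id"]                     (names in the order they first appear)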
diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py
new file mode 100644
index 0000000..28a6084
--- /dev/null
+++ b/psycopg/psycopg/_struct.py
@@ -0,0 +1,57 @@
+"""
+Utility functions to deal with binary structs.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from typing import Callable, cast, Optional, Tuple
+from typing_extensions import TypeAlias
+
+from .abc import Buffer
+from . import errors as e
+from ._compat import Protocol
+
+PackInt: TypeAlias = Callable[[int], bytes]
+UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
+PackFloat: TypeAlias = Callable[[float], bytes]
+UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
+
+
+class UnpackLen(Protocol):
+ def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
+ ...
+
+
+pack_int2 = cast(PackInt, struct.Struct("!h").pack)
+pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
+pack_int4 = cast(PackInt, struct.Struct("!i").pack)
+pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
+pack_int8 = cast(PackInt, struct.Struct("!q").pack)
+pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
+pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
+
+unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
+unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
+unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
+unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
+unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
+unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
+unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
+
+_struct_len = struct.Struct("!i")
+pack_len = cast(Callable[[int], bytes], _struct_len.pack)
+unpack_len = cast(UnpackLen, _struct_len.unpack_from)
+
+
+def pack_float4_bug_304(x: float) -> bytes:
+ raise e.InterfaceError(
+ "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
+ " and psycopg-binary packages are not affected by this issue."
+ " See https://github.com/psycopg/psycopg/issues/304"
+ )
+
+
+# If issue #304 is detected, raise an error instead of dumping wrong data.
+if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
+ pack_float4 = pack_float4_bug_304
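The packers and unpackers above are thin wrappers over struct in network byte order; for example (standalone sketch using the private module path):

    from psycopg._struct import pack_int4, unpack_int4

    data = pack_int4(42)          # b"\x00\x00\x00*" (big-endian, 4 bytes)
    (value,) = unpack_int4(data)  # value == 42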
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
new file mode 100644
index 0000000..3528188
--- /dev/null
+++ b/psycopg/psycopg/_tpc.py
@@ -0,0 +1,116 @@
+"""
+psycopg two-phase commit support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+import datetime as dt
+from base64 import b64encode, b64decode
+from typing import Optional, Union
+from dataclasses import dataclass, replace
+
+_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
+
+
+@dataclass(frozen=True)
+class Xid:
+ """A two-phase commit transaction identifier.
+
+ The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
+ `bqual`).
+
+ """
+
+ format_id: Optional[int]
+ gtrid: str
+ bqual: Optional[str]
+ prepared: Optional[dt.datetime] = None
+ owner: Optional[str] = None
+ database: Optional[str] = None
+
+ @classmethod
+ def from_string(cls, s: str) -> "Xid":
+ """Try to parse an XA triple from the string.
+
+ This may fail for several reasons. In such a case, return an unparsed Xid.
+ """
+ try:
+ return cls._parse_string(s)
+ except Exception:
+ return Xid(None, s, None)
+
+ def __str__(self) -> str:
+ return self._as_tid()
+
+ def __len__(self) -> int:
+ return 3
+
+ def __getitem__(self, index: int) -> Union[int, str, None]:
+ return (self.format_id, self.gtrid, self.bqual)[index]
+
+ @classmethod
+ def _parse_string(cls, s: str) -> "Xid":
+ m = _re_xid.match(s)
+ if not m:
+ raise ValueError("bad Xid format")
+
+ format_id = int(m.group(1))
+ gtrid = b64decode(m.group(2)).decode()
+ bqual = b64decode(m.group(3)).decode()
+ return cls.from_parts(format_id, gtrid, bqual)
+
+ @classmethod
+ def from_parts(
+ cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
+ ) -> "Xid":
+ if format_id is not None:
+ if bqual is None:
+ raise TypeError("if format_id is specified, bqual must be too")
+ if not 0 <= format_id < 0x80000000:
+ raise ValueError("format_id must be a non-negative 32-bit integer")
+ if len(bqual) > 64:
+ raise ValueError("bqual must be not longer than 64 chars")
+ if len(gtrid) > 64:
+ raise ValueError("gtrid must be not longer than 64 chars")
+
+ elif bqual is None:
+ raise TypeError("if format_id is None, bqual must be None too")
+
+ return Xid(format_id, gtrid, bqual)
+
+ def _as_tid(self) -> str:
+ """
+ Return the PostgreSQL transaction_id for this XA xid.
+
+ PostgreSQL wants just a string, while the DBAPI supports the XA
+ standard and thus a triple. We use the same conversion algorithm
+ implemented by JDBC in order to allow some form of interoperation.
+
+ see also: the pgjdbc implementation
+ http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
+ postgresql/xa/RecoveredXid.java?rev=1.2
+ """
+ if self.format_id is None or self.bqual is None:
+ # Unparsed xid: return the gtrid.
+ return self.gtrid
+
+ # XA xid: mash together the components.
+ egtrid = b64encode(self.gtrid.encode()).decode()
+ ebqual = b64encode(self.bqual.encode()).decode()
+
+ return f"{self.format_id}_{egtrid}_{ebqual}"
+
+ @classmethod
+ def _get_recover_query(cls) -> str:
+ return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"
+
+ @classmethod
+ def _from_record(
+ cls, gid: str, prepared: dt.datetime, owner: str, database: str
+ ) -> "Xid":
+ xid = Xid.from_string(gid)
+ return replace(xid, prepared=prepared, owner=owner, database=database)
+
+
+Xid.__module__ = "psycopg"
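A short round-trip sketch of the Xid string format used for prepared transactions (the private module path is an assumption of this illustration):

    from psycopg._tpc import Xid

    xid = Xid.from_parts(42, "gtrid", "bqual")
    tid = str(xid)                     # "42_Z3RyaWQ=_YnF1YWw=" (base64-encoded parts)
    assert Xid.from_string(tid) == xid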
diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py
new file mode 100644
index 0000000..19bd6ae
--- /dev/null
+++ b/psycopg/psycopg/_transform.py
@@ -0,0 +1,350 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import DefaultDict, TYPE_CHECKING
+from collections import defaultdict
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import postgres
+from . import errors as e
+from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
+from .rows import Row, RowMaker
+from .postgres import INVALID_OID, TEXT_OID
+from ._encodings import pgconn_encoding
+
+if TYPE_CHECKING:
+ from .abc import Dumper, Loader
+ from .adapt import AdaptersMap
+ from .pq.abc import PGresult
+ from .connection import BaseConnection
+
+DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
+OidDumperCache: TypeAlias = Dict[int, "Dumper"]
+LoaderCache: TypeAlias = Dict[int, "Loader"]
+
+TEXT = pq.Format.TEXT
+PY_TEXT = PyFormat.TEXT
+
+
+class Transformer(AdaptContext):
+ """
+ An object that can adapt efficiently between Python and PostgreSQL.
+
+ The life cycle of the object is the query, so it is assumed that attributes
+ such as the server version or the connection encoding will not change. The
+ object has its own state, so adapting several values of the same type can be
+ optimised.
+
+ """
+
+ __module__ = "psycopg.adapt"
+
+ __slots__ = """
+ types formats
+ _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
+ _oid_dumpers _oid_types _row_dumpers _row_loaders
+ """.split()
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ _adapters: "AdaptersMap"
+ _pgresult: Optional["PGresult"]
+ _none_oid: int
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ self._pgresult = self.types = self.formats = None
+
+ # WARNING: don't store context, or you'll create a loop with the Cursor
+ if context:
+ self._adapters = context.adapters
+ self._conn = context.connection
+ else:
+ self._adapters = postgres.adapters
+ self._conn = None
+
+ # mapping fmt, class -> Dumper instance
+ self._dumpers: DefaultDict[PyFormat, DumperCache]
+ self._dumpers = defaultdict(dict)
+
+ # mapping fmt, oid -> Dumper instance
+ # Not often used, so create it only if needed.
+ self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
+ self._oid_dumpers = None
+
+ # mapping fmt, oid -> Loader instance
+ self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
+
+ self._row_dumpers: Optional[List["Dumper"]] = None
+
+ # sequence of load functions from value to python,
+ # one per column of the result
+ self._row_loaders: List[LoadFunc] = []
+
+ # mapping oid -> type sql representation
+ self._oid_types: Dict[int, bytes] = {}
+
+ self._encoding = ""
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ """
+ Return a Transformer from an AdaptContext.
+
+ If the context is a Transformer instance, just return it.
+ """
+ if isinstance(context, Transformer):
+ return context
+ else:
+ return cls(context)
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ return self._conn
+
+ @property
+ def encoding(self) -> str:
+ if not self._encoding:
+ conn = self.connection
+ self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+ return self._encoding
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ return self._adapters
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ return self._pgresult
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None,
+ ) -> None:
+ self._pgresult = result
+
+ if not result:
+ self._nfields = self._ntuples = 0
+ if set_loaders:
+ self._row_loaders = []
+ return
+
+ self._ntuples = result.ntuples
+ nf = self._nfields = result.nfields
+
+ if not set_loaders:
+ return
+
+ if not nf:
+ self._row_loaders = []
+ return
+
+ fmt: pq.Format
+ fmt = result.fformat(0) if format is None else format # type: ignore
+ self._row_loaders = [
+ self.get_loader(result.ftype(i), fmt).load for i in range(nf)
+ ]
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
+ self.types = tuple(types)
+ self.formats = [format] * len(types)
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_loaders = [self.get_loader(oid, format).load for oid in types]
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ nparams = len(params)
+ out: List[Optional[Buffer]] = [None] * nparams
+
+ # If we have dumpers, it means set_dumper_types has been called, in
+ # which case self.types and self.formats are set to sequences of the
+ # right size.
+ if self._row_dumpers:
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ out[i] = self._row_dumpers[i].dump(param)
+ return out
+
+ types = [self._get_none_oid()] * nparams
+ pqformats = [TEXT] * nparams
+
+ for i in range(nparams):
+ param = params[i]
+ if param is None:
+ continue
+ dumper = self.get_dumper(param, formats[i])
+ out[i] = dumper.dump(param)
+ types[i] = dumper.oid
+ pqformats[i] = dumper.format
+
+ self.types = tuple(types)
+ self.formats = pqformats
+
+ return out
+
+ def as_literal(self, obj: Any) -> bytes:
+ dumper = self.get_dumper(obj, PY_TEXT)
+ rv = dumper.quote(obj)
+ # If the result is quoted, and the oid is neither unknown nor text,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ oid = dumper.oid
+ if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
+ try:
+ type_sql = self._oid_types[oid]
+ except KeyError:
+ ti = self.adapters.types.get(oid)
+ if ti:
+ if oid < 8192:
+ # builtin: prefer "timestamptz" to "timestamp with time zone"
+ type_sql = ti.name.encode(self.encoding)
+ else:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+ else:
+ type_sql = b""
+ self._oid_types[oid] = type_sql
+
+ if type_sql:
+ rv = b"%s::%s" % (rv, type_sql)
+
+ if not isinstance(rv, bytes):
+ rv = bytes(rv)
+ return rv
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Return a Dumper instance to dump `!obj`.
+ """
+ # Normally, the type of the object dictates how to dump it
+ key = type(obj)
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._dumpers[format]
+ try:
+ dumper = cache[key]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper(key, format)
+ cache[key] = dumper = dcls(key, self)
+
+ # Check if the dumper requires an upgrade to handle this specific value
+ key1 = dumper.get_key(obj, format)
+ if key1 is key:
+ return dumper
+
+ # If it does, ask the dumper to create its own upgraded version
+ try:
+ return cache[key1]
+ except KeyError:
+ dumper = cache[key1] = dumper.upgrade(obj, format)
+ return dumper
+
+ def _get_none_oid(self) -> int:
+ try:
+ return self._none_oid
+ except AttributeError:
+ pass
+
+ try:
+ rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
+ except KeyError:
+ raise e.InterfaceError("None dumper not found")
+
+ return rv
+
+ def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
+ """
+ Return a Dumper to dump an object to the type with given oid.
+ """
+ if not self._oid_dumpers:
+ self._oid_dumpers = ({}, {})
+
+ # Reuse an existing Dumper class for objects of the same type
+ cache = self._oid_dumpers[format]
+ try:
+ return cache[oid]
+ except KeyError:
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper_by_oid(oid, format)
+ cache[oid] = dumper = dcls(NoneType, self)
+
+ return dumper
+
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
+ res = self._pgresult
+ if not res:
+ raise e.InterfaceError("result not set")
+
+ if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
+ raise e.InterfaceError(
+ f"rows must be included between 0 and {self._ntuples}"
+ )
+
+ records = []
+ for row in range(row0, row1):
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+ records.append(make_row(record))
+
+ return records
+
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
+ res = self._pgresult
+ if not res:
+ return None
+
+ if not 0 <= row < self._ntuples:
+ return None
+
+ record: List[Any] = [None] * self._nfields
+ for col in range(self._nfields):
+ val = res.get_value(row, col)
+ if val is not None:
+ record[col] = self._row_loaders[col](val)
+
+ return make_row(record)
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ if len(self._row_loaders) != len(record):
+ raise e.ProgrammingError(
+ f"cannot load sequence of {len(record)} items:"
+ f" {len(self._row_loaders)} loaders registered"
+ )
+
+ return tuple(
+ (self._row_loaders[i](val) if val is not None else None)
+ for i, val in enumerate(record)
+ )
+
+ def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ try:
+ return self._loaders[format][oid]
+ except KeyError:
+ pass
+
+ loader_cls = self._adapters.get_loader(oid, format)
+ if not loader_cls:
+ loader_cls = self._adapters.get_loader(INVALID_OID, format)
+ if not loader_cls:
+ raise e.InterfaceError("unknown oid loader not found")
+ loader = self._loaders[format][oid] = loader_cls(oid, self)
+ return loader
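A minimal sketch of how a Transformer adapts a sequence of parameters, assuming the default global adapters loaded by importing psycopg:

    import psycopg  # loads the default adapters
    from psycopg.adapt import Transformer, PyFormat

    tx = Transformer()
    out = tx.dump_sequence([1, "foo", None], [PyFormat.AUTO] * 3)
    # out        -> adapted buffers, e.g. [b"1", b"foo", None]
    # tx.types   -> the oids chosen by the dumpers (None keeps the "none" oid)
    # tx.formats -> the per-parameter wire formats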
diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py
new file mode 100644
index 0000000..2f1a24d
--- /dev/null
+++ b/psycopg/psycopg/_typeinfo.py
@@ -0,0 +1,461 @@
+"""
+Information about PostgreSQL types
+
+These types allow reading information from the system catalog and provide
+information to the adapters if needed.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+from enum import Enum
+from typing import Any, Dict, Iterator, Optional, overload
+from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import errors as e
+from .abc import AdaptContext
+from .rows import dict_row
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+ from .sql import Identifier
+
+T = TypeVar("T", bound="TypeInfo")
+RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
+
+
+class TypeInfo:
+ """
+ Hold information about a PostgreSQL base type.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ delimiter: str = ",",
+ ):
+ self.name = name
+ self.oid = oid
+ self.array_oid = array_oid
+ self.regtype = regtype or name
+ self.delimiter = delimiter
+
+ def __repr__(self) -> str:
+ return (
+ f"<{self.__class__.__qualname__}:"
+ f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
+ )
+
+ @overload
+ @classmethod
+ def fetch(
+ cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
+ ) -> Optional[T]:
+ ...
+
+ @overload
+ @classmethod
+ async def fetch(
+ cls: Type[T],
+ conn: "AsyncConnection[Any]",
+ name: Union[str, "Identifier"],
+ ) -> Optional[T]:
+ ...
+
+ @classmethod
+ def fetch(
+ cls: Type[T],
+ conn: "Union[Connection[Any], AsyncConnection[Any]]",
+ name: Union[str, "Identifier"],
+ ) -> Any:
+ """Query a system catalog to read information about a type."""
+ from .sql import Composable
+ from .connection_async import AsyncConnection
+
+ if isinstance(name, Composable):
+ name = name.as_string(conn)
+
+ if isinstance(conn, AsyncConnection):
+ return cls._fetch_async(conn, name)
+
+ # This might result in a nested transaction. What we want is to leave
+ # the function with the connection in the state we found (either idle
+ # or intrans)
+ try:
+ with conn.transaction():
+ with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ async def _fetch_async(
+ cls: Type[T], conn: "AsyncConnection[Any]", name: str
+ ) -> Optional[T]:
+ """
+ Query a system catalog to read information about a type.
+
+ Similar to `fetch()` but can use an asynchronous connection.
+ """
+ try:
+ async with conn.transaction():
+ async with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ await cur.execute(cls._get_info_query(conn), {"name": name})
+ recs = await cur.fetchall()
+ except e.UndefinedObject:
+ return None
+
+ return cls._from_records(name, recs)
+
+ @classmethod
+ def _from_records(
+ cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
+ ) -> Optional[T]:
+ if len(recs) == 1:
+ return cls(**recs[0])
+ elif not recs:
+ return None
+ else:
+ raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
+
+ def register(self, context: Optional[AdaptContext] = None) -> None:
+ """
+ Register the type information, globally or in the specified `!context`.
+ """
+ if context:
+ types = context.adapters.types
+ else:
+ from . import postgres
+
+ types = postgres.types
+
+ types.add(self)
+
+ if self.array_oid:
+ from .types.array import register_array
+
+ register_array(self, context)
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ typname AS name, oid, typarray AS array_oid,
+ oid::regtype::text AS regtype, typdelim AS delimiter
+FROM pg_type t
+WHERE t.oid = %(name)s::regtype
+ORDER BY t.oid
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ """Method called by the `!registry` when the object is added there."""
+ pass
+
+
+class RangeInfo(TypeInfo):
+ """Manage information about a range type."""
+
+ __module__ = "psycopg.types.range"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngtypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map ranges subtypes to info
+ registry._registry[RangeInfo, self.subtype_oid] = self
+
+
+class MultirangeInfo(TypeInfo):
+ """Manage information about a multirange type."""
+
+ # TODO: expose to multirange module once added
+ # __module__ = "psycopg.types.multirange"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ range_oid: int,
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.range_oid = range_oid
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ if conn.info.server_version < 140000:
+ raise e.NotSupportedError(
+ "multirange types are only available from PostgreSQL 14"
+ )
+ return """\
+SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
+FROM pg_type t
+JOIN pg_range r ON t.oid = r.rngmultitypid
+WHERE t.oid = %(name)s::regtype
+"""
+
+ def _added(self, registry: "TypesRegistry") -> None:
+ # Map multiranges ranges and subtypes to info
+ registry._registry[MultirangeInfo, self.range_oid] = self
+ registry._registry[MultirangeInfo, self.subtype_oid] = self
+
+
+class CompositeInfo(TypeInfo):
+ """Manage information about a composite type."""
+
+ __module__ = "psycopg.types.composite"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ *,
+ regtype: str = "",
+ field_names: Sequence[str],
+ field_types: Sequence[int],
+ ):
+ super().__init__(name, oid, array_oid, regtype=regtype)
+ self.field_names = field_names
+ self.field_types = field_types
+ # Will be set by register() if the `factory` is a type
+ self.python_type: Optional[type] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ t.oid::regtype::text AS regtype,
+ coalesce(a.fnames, '{}') AS field_names,
+ coalesce(a.ftypes, '{}') AS field_types
+FROM pg_type t
+LEFT JOIN (
+ SELECT
+ attrelid,
+ array_agg(attname) AS fnames,
+ array_agg(atttypid) AS ftypes
+ FROM (
+ SELECT a.attrelid, a.attname, a.atttypid
+ FROM pg_attribute a
+ JOIN pg_type t ON t.typrelid = a.attrelid
+ WHERE t.oid = %(name)s::regtype
+ AND a.attnum > 0
+ AND NOT a.attisdropped
+ ORDER BY a.attnum
+ ) x
+ GROUP BY attrelid
+) a ON a.attrelid = t.typrelid
+WHERE t.oid = %(name)s::regtype
+"""
+
+
+class EnumInfo(TypeInfo):
+ """Manage information about an enum type."""
+
+ __module__ = "psycopg.types.enum"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ labels: Sequence[str],
+ ):
+ super().__init__(name, oid, array_oid)
+ self.labels = labels
+ # Will be set by register_enum()
+ self.enum: Optional[Type[Enum]] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT name, oid, array_oid, array_agg(label) AS labels
+FROM (
+ SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ e.enumlabel AS label
+ FROM pg_type t
+ LEFT JOIN pg_enum e
+ ON e.enumtypid = t.oid
+ WHERE t.oid = %(name)s::regtype
+ ORDER BY e.enumsortorder
+) x
+GROUP BY name, oid, array_oid
+"""
+
+
+class TypesRegistry:
+ """
+ Container for the information about types in a database.
+ """
+
+ __module__ = "psycopg.types"
+
+ def __init__(self, template: Optional["TypesRegistry"] = None):
+ self._registry: Dict[RegistryKey, TypeInfo]
+
+ # Make a shallow copy: it will become a proper copy if the registry
+ # is edited.
+ if template:
+ self._registry = template._registry
+ self._own_state = False
+ template._own_state = False
+ else:
+ self.clear()
+
+ def clear(self) -> None:
+ self._registry = {}
+ self._own_state = True
+
+ def add(self, info: TypeInfo) -> None:
+ self._ensure_own_state()
+ if info.oid:
+ self._registry[info.oid] = info
+ if info.array_oid:
+ self._registry[info.array_oid] = info
+ self._registry[info.name] = info
+
+ if info.regtype and info.regtype not in self._registry:
+ self._registry[info.regtype] = info
+
+ # Allow the info to further customise its relation with the registry
+ info._added(self)
+
+ def __iter__(self) -> Iterator[TypeInfo]:
+ seen = set()
+ for t in self._registry.values():
+ if id(t) not in seen:
+ seen.add(id(t))
+ yield t
+
+ @overload
+ def __getitem__(self, key: Union[str, int]) -> TypeInfo:
+ ...
+
+ @overload
+ def __getitem__(self, key: Tuple[Type[T], int]) -> T:
+ ...
+
+ def __getitem__(self, key: RegistryKey) -> TypeInfo:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Raise KeyError if not found.
+ """
+ if isinstance(key, str):
+ if key.endswith("[]"):
+ key = key[:-2]
+ elif not isinstance(key, (int, tuple)):
+ raise TypeError(f"the key must be an oid or a name, got {type(key)}")
+ try:
+ return self._registry[key]
+ except KeyError:
+ raise KeyError(f"couldn't find the type {key!r} in the types registry")
+
+ @overload
+ def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
+ ...
+
+ def get(self, key: RegistryKey) -> Optional[TypeInfo]:
+ """
+ Return info about a type, specified by name or oid
+
+ :param key: the name or oid of the type to look for.
+
+ Unlike `__getitem__`, return None if not found.
+ """
+ try:
+ return self[key]
+ except KeyError:
+ return None
+
+ def get_oid(self, name: str) -> int:
+ """
+ Return the oid of a PostgreSQL type by name.
+
+ :param name: the name of the type to look for.
+
+ Return the array oid if the type ends with "``[]``"
+
+ Raise KeyError if the name is unknown.
+ """
+ t = self[name]
+ if name.endswith("[]"):
+ return t.array_oid
+ else:
+ return t.oid
+
+ def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
+ """
+ Return info about a `TypeInfo` subclass by its element name or oid.
+
+ :param cls: the subtype of `!TypeInfo` to look for. Currently
+ supported are `~psycopg.types.range.RangeInfo` and
+ `~psycopg.types.multirange.MultirangeInfo`.
+ :param subtype: The name or OID of the subtype of the element to look for.
+ :return: The `!TypeInfo` object of class `!cls` whose subtype is
+ `!subtype`. `!None` if the element or its range are not found.
+ """
+ try:
+ info = self[subtype]
+ except KeyError:
+ return None
+ return self.get((cls, info.oid))
+
+ def _ensure_own_state(self) -> None:
+ # Time to write! so, copy.
+ if not self._own_state:
+ self._registry = self._registry.copy()
+ self._own_state = True
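TypeInfo.fetch() and register() are typically used to teach the adapters about an extension or custom type; a minimal sketch (DSN and type name are placeholders):

    import psycopg
    from psycopg.types import TypeInfo

    with psycopg.connect("dbname=test") as conn:  # illustrative DSN
        info = TypeInfo.fetch(conn, "hstore")     # query pg_type for the type
        if info is not None:
            info.register(conn)                   # make it known to this connection
            print(info.name, info.oid, info.array_oid)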
diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py
new file mode 100644
index 0000000..813ed62
--- /dev/null
+++ b/psycopg/psycopg/_tz.py
@@ -0,0 +1,44 @@
+"""
+Timezone utility functions.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import Dict, Optional, Union
+from datetime import timezone, tzinfo
+
+from .pq.abc import PGconn
+from ._compat import ZoneInfo
+
+logger = logging.getLogger("psycopg")
+
+_timezones: Dict[Union[None, bytes], tzinfo] = {
+ None: timezone.utc,
+ b"UTC": timezone.utc,
+}
+
+
+def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
+ """Return the Python timezone info of the connection's timezone."""
+ tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
+ try:
+ return _timezones[tzname]
+ except KeyError:
+ sname = tzname.decode() if tzname else "UTC"
+ try:
+ zi: tzinfo = ZoneInfo(sname)
+ except (KeyError, OSError):
+ logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
+ zi = timezone.utc
+ except Exception as ex:
+ logger.warning(
+ "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
+ sname,
+ type(ex).__name__,
+ ex,
+ )
+ zi = timezone.utc
+
+ _timezones[tzname] = zi
+ return zi
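get_tzinfo() caches the mapping from the connection TimeZone parameter to a Python tzinfo and falls back to UTC without a connection (sketch using the private module path):

    from datetime import timezone
    from psycopg._tz import get_tzinfo

    assert get_tzinfo(None) is timezone.utc  # no connection info: fall back to UTC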
diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py
new file mode 100644
index 0000000..f861741
--- /dev/null
+++ b/psycopg/psycopg/_wrappers.py
@@ -0,0 +1,137 @@
+"""
+Wrappers for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+# These types are implemented here but exposed by `psycopg.types.numeric`.
+# They are defined here to avoid a circular import.
+_MODULE = "psycopg.types.numeric"
+
+
+class Int2(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int2":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int4(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Int8(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Int8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class IntNumeric(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "IntNumeric":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float4(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float4":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Float8(float):
+ """
+ Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: float) -> "Float8":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
+
+
+class Oid(int):
+ """
+ Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
+ """
+
+ __module__ = _MODULE
+ __slots__ = ()
+
+ def __new__(cls, arg: int) -> "Oid":
+ return super().__new__(cls, arg)
+
+ def __str__(self) -> str:
+ return super().__repr__()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({super().__repr__()})"
diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py
new file mode 100644
index 0000000..80c8fbf
--- /dev/null
+++ b/psycopg/psycopg/abc.py
@@ -0,0 +1,266 @@
+"""
+Protocol objects representing different implementations of the same classes.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Generator, Mapping
+from typing import List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from . import pq
+from ._enums import PyFormat as PyFormat
+from ._compat import Protocol, LiteralString
+
+if TYPE_CHECKING:
+ from . import sql
+ from .rows import Row, RowMaker
+ from .pq.abc import PGresult
+ from .waiting import Wait, Ready
+ from .connection import BaseConnection
+ from ._adapters_map import AdaptersMap
+
+NoneType: type = type(None)
+
+# An object implementing the buffer protocol
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
+Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
+PipelineCommand: TypeAlias = Callable[[], None]
+DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+
+# Waiting protocol types
+
+RV = TypeVar("RV")
+
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+"""Generator for processes where the connection file number can change.
+
+This can happen during connection and reset, but not during normal querying.
+"""
+
+PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+"""Generator for processes where the connection file number won't change.
+"""
+
+
+class WaitFunc(Protocol):
+ """
+ Wait on the connection which generated the `PQGen` and return its final result.
+ """
+
+ def __call__(
+ self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
+ ) -> RV:
+ ...
+
+
+# Adaptation types
+
+DumpFunc: TypeAlias = Callable[[Any], Buffer]
+LoadFunc: TypeAlias = Callable[[Buffer], Any]
+
+
+class AdaptContext(Protocol):
+ """
+ A context describing how types are adapted.
+
+ Examples of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
+ `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
+
+ Note that this is a `~typing.Protocol`, so objects implementing
+ `!AdaptContext` don't need to explicitly inherit from this class.
+
+ """
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ """The adapters configuration that this object uses."""
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ """The connection used by this object, if available.
+
+ :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
+ """
+ ...
+
+
+class Dumper(Protocol):
+ """
+ Convert Python objects of type `!cls` to PostgreSQL representation.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `dump()` method produces,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ oid: int
+ """The oid to pass to the server, if known; 0 otherwise (class attribute)."""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ ...
+
+ def dump(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to PostgreSQL representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """Convert the object `!obj` to escaped representation.
+
+ :param obj: the object to convert.
+ """
+ ...
+
+ def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
+ """Return an alternative key to upgrade the dumper to represent `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Normally the type of the object is all it takes to define how to dump
+ the object to the database. For instance, a Python `~datetime.date` can
+ be simply converted into a PostgreSQL :sql:`date`.
+
+ In a few cases, just the type is not enough. For example:
+
+ - A Python `~datetime.datetime` could be represented as a
+ :sql:`timestamptz` or a :sql:`timestamp`, according to whether it
+ specifies a `!tzinfo` or not.
+
+ - A Python int could be stored as several Postgres types: int2, int4,
+ int8, numeric. If a type too small is used, it may result in an
+ overflow. If a type too large is used, PostgreSQL may not want to
+ cast it to a smaller type.
+
+ - Python lists should be dumped according to the type they contain to
+ convert them to e.g. array of strings, array of ints (and which
+ size of int?...)
+
+ In these cases, a dumper can implement `!get_key()` and return a new
+ class, or sequence of classes, that can be used to identify the same
+ dumper again. If the mechanism is not needed, the method should return
+ the same `!cls` object passed in the constructor.
+
+ If a dumper implements `get_key()` it should also implement
+ `upgrade()`.
+
+ """
+ ...
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """Return a new dumper to manage `!obj`.
+
+ :param obj: The object to convert
+ :param format: The format to convert to
+
+ Once `Transformer.get_dumper()` has been notified by `get_key()` that
+ this Dumper class cannot handle `!obj` itself, it will invoke
+ `!upgrade()`, which should return a new `Dumper` instance, which will
+ be reused for every object for which `!get_key()` returns the same
+ result.
+ """
+ ...
+
+
+class Loader(Protocol):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format
+ """
+ The format that this class `load()` method can convert,
+ `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
+
+ This is a class attribute.
+ """
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ ...
+
+ def load(self, data: Buffer) -> Any:
+ """
+ Convert the data returned by the database into a Python object.
+
+ :param data: the data to convert.
+ """
+ ...
+
+
+class Transformer(Protocol):
+
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+
+ def __init__(self, context: Optional[AdaptContext] = None):
+ ...
+
+ @classmethod
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
+ ...
+
+ @property
+ def connection(self) -> Optional["BaseConnection[Any]"]:
+ ...
+
+ @property
+ def encoding(self) -> str:
+ ...
+
+ @property
+ def adapters(self) -> "AdaptersMap":
+ ...
+
+ @property
+ def pgresult(self) -> Optional["PGresult"]:
+ ...
+
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None
+ ) -> None:
+ ...
+
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ ...
+
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[Buffer]]:
+ ...
+
+ def as_literal(self, obj: Any) -> bytes:
+ ...
+
+ def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
+ ...
+
+ def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
+ ...
+
+ def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
+ ...
+
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
+ ...
+
+ def get_loader(self, oid: int, format: pq.Format) -> Loader:
+ ...
diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py
new file mode 100644
index 0000000..7ec4a55
--- /dev/null
+++ b/psycopg/psycopg/adapt.py
@@ -0,0 +1,162 @@
+"""
+Entry point into the adaptation system.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Type, TYPE_CHECKING
+
+from . import pq, abc
+from . import _adapters_map
+from ._enums import PyFormat as PyFormat
+from ._cmodule import _psycopg
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
+
+AdaptersMap = _adapters_map.AdaptersMap
+Buffer = abc.Buffer
+
+ORD_BS = ord("\\")
+
+
+class Dumper(abc.Dumper, ABC):
+ """
+ Convert Python objects of the type `!cls` to PostgreSQL representation.
+ """
+
+ oid: int = 0
+ """The oid to pass to the server, if known."""
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data dumped."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ self.cls = cls
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ def __repr__(self) -> str:
+ return (
+ f"<{type(self).__module__}.{type(self).__qualname__}"
+ f" (oid={self.oid}) at 0x{id(self):x}>"
+ )
+
+ @abstractmethod
+ def dump(self, obj: Any) -> Buffer:
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ """
+ By default return the `dump()` value quoted and sanitised, so
+ that the result can be used to build a SQL string. This works well
+ for most types and you won't likely have to implement this method in a
+ subclass.
+ """
+ value = self.dump(obj)
+
+ if self.connection:
+ esc = pq.Escaping(self.connection.pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+
+ # This path is taken when quote() is called without a connection,
+ # usually by psycopg.sql.quote() or by
+ # 'Composable.as_string(None)'. More often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ # No quoting, only quote escaping; backslash escaping is unreliable. See below.
+ esc = pq.Escaping()
+ out = esc.escape_string(value)
+
+ # b"\\" in memoryview doesn't work so search for the ascii value
+ if ORD_BS not in out:
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ return b"'" + out + b"'"
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ rv: bytes = b" E'" + out + b"'"
+ if esc.escape_string(b"\\") == b"\\":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+ def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
+ """
+ Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation returns the `!cls` passed in the constructor.
+ Subclasses needing to specialise the PostgreSQL type according to the
+ *value* of the object dumped (not only according to its type)
+ should override this method.
+
+ """
+ return self.cls
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ """
+ Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
+ `~psycopg.abc.Dumper` protocol. Look at its definition for details.
+
+ This implementation just returns `!self`. If a subclass implements
+ `get_key()` it should probably override `!upgrade()` too.
+ """
+ return self
+
+
+class Loader(abc.Loader, ABC):
+ """
+ Convert PostgreSQL values with type OID `!oid` to Python objects.
+ """
+
+ format: pq.Format = pq.Format.TEXT
+ """The format of the data loaded."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ self.oid = oid
+ self.connection: Optional["BaseConnection[Any]"] = (
+ context.connection if context else None
+ )
+
+ @abstractmethod
+ def load(self, data: Buffer) -> Any:
+ """Convert a PostgreSQL value to a Python object."""
+ ...
+
+
+Transformer: Type["abc.Transformer"]
+
+# Override it with fast object if available
+if _psycopg:
+ Transformer = _psycopg.Transformer
+else:
+ from . import _transform
+
+ Transformer = _transform.Transformer
+
+
+class RecursiveDumper(Dumper):
+ """Dumper with a transformer to help dumping recursive types."""
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self._tx = Transformer.from_context(context)
+
+
+class RecursiveLoader(Loader):
+ """Loader with a transformer to help loading recursive types."""
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer.from_context(context)
diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py
new file mode 100644
index 0000000..6271ec5
--- /dev/null
+++ b/psycopg/psycopg/client_cursor.py
@@ -0,0 +1,95 @@
+"""
+psycopg client-side binding cursors
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from typing import Optional, Tuple, TYPE_CHECKING
+from functools import partial
+
+from ._queries import PostgresQuery, PostgresClientQuery
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params
+from .rows import Row
+from .cursor import BaseCursor, Cursor
+from ._preparing import Prepare
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from typing import Any # noqa: F401
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+
+class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
+ def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
+ """
+ Return the query and parameters merged.
+
+ Parameters are adapted and merged into the query the same way that
+ `!execute()` would do.
+
+ """
+ self._tx = adapt.Transformer(self)
+ pgq = self._convert_query(query, params)
+ return pgq.query.decode(self._tx.encoding)
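+
+ # Illustrative usage sketch (assuming `cur` is a ClientCursor obtained
+ # from a connection):
+ #
+ #     cur.mogrify("SELECT %s, %s", (10, "foo"))   # -> "SELECT 10, 'foo'"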
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if fmt == BINARY:
+ raise e.NotSupportedError(
+ "client-side cursors don't support binary results"
+ )
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(self._pgconn.send_query_params, query.query, None)
+ )
+ elif force_extended:
+ self._pgconn.send_query_params(query.query, None)
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return (Prepare.NO, b"")
+
+
+class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+
+
+class AsyncClientCursor(
+ ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
new file mode 100644
index 0000000..78ad577
--- /dev/null
+++ b/psycopg/psycopg/connection.py
@@ -0,0 +1,1031 @@
+"""
+psycopg connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+import threading
+from types import TracebackType
+from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
+from typing import overload, TYPE_CHECKING
+from weakref import ref, ReferenceType
+from warnings import warn
+from functools import partial
+from contextlib import contextmanager
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from . import waiting
+from . import postgres
+from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import PQGen, PQGenConn
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .cursor import Cursor
+from ._compat import LiteralString
+from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from ._pipeline import BasePipeline, Pipeline
+from .generators import notifies, connect, execute
+from ._encodings import pgconn_encoding
+from ._preparing import PrepareManager
+from .transaction import Transaction
+from .server_cursor import ServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn, PGresult
+ from psycopg_pool.base import BasePool
+
+
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class Notify(NamedTuple):
+ """An asynchronous notification received from the database."""
+
+ channel: str
+ """The name of the channel on which the notification was received."""
+
+ payload: str
+ """The message attached to the notification."""
+
+ pid: int
+ """The PID of the backend process which sent the notification."""
+
+
+Notify.__module__ = "psycopg"
+
+NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
+NotifyHandler: TypeAlias = Callable[[Notify], None]
+
+
+class BaseConnection(Generic[Row]):
+ """
+ Base class for different types of connections.
+
+ Share common functionalities such as access to the wrapped PGconn, but
+ allow different interfaces (sync/async).
+ """
+
+ # DBAPI2 exposed exceptions
+ Warning = e.Warning
+ Error = e.Error
+ InterfaceError = e.InterfaceError
+ DatabaseError = e.DatabaseError
+ DataError = e.DataError
+ OperationalError = e.OperationalError
+ IntegrityError = e.IntegrityError
+ InternalError = e.InternalError
+ ProgrammingError = e.ProgrammingError
+ NotSupportedError = e.NotSupportedError
+
+ # Enums useful for the connection
+ ConnStatus = pq.ConnStatus
+ TransactionStatus = pq.TransactionStatus
+
+ def __init__(self, pgconn: "PGconn"):
+ self.pgconn = pgconn
+ self._autocommit = False
+
+ # None, but set to a copy of the global adapters map as soon as requested.
+ self._adapters: Optional[AdaptersMap] = None
+
+ self._notice_handlers: List[NoticeHandler] = []
+ self._notify_handlers: List[NotifyHandler] = []
+
+ # Number of transaction blocks currently entered
+ self._num_transactions = 0
+
+ self._closed = False # closed by an explicit close()
+ self._prepared: PrepareManager = PrepareManager()
+ self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared
+
+ wself = ref(self)
+ pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+ pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
+
+ # The attribute is set only if the connection comes from a pool, so we
+ # can also tell apart a connection currently in the pool (when _pool is None).
+ self._pool: Optional["BasePool[Any]"]
+
+ self._pipeline: Optional[BasePipeline] = None
+
+ # Time after which the connection should be closed
+ self._expire_at: float
+
+ self._isolation_level: Optional[IsolationLevel] = None
+ self._read_only: Optional[bool] = None
+ self._deferrable: Optional[bool] = None
+ self._begin_statement = b""
+
+ def __del__(self) -> None:
+ # If fails on connection we might not have this attribute yet
+ if not hasattr(self, "pgconn"):
+ return
+
+ # Connection correctly closed
+ if self.closed:
+ return
+
+ # Connection in a pool so terminating with the program is normal
+ if hasattr(self, "_pool"):
+ return
+
+ warn(
+ f"connection {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the connection",
+ ResourceWarning,
+ )
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @property
+ def closed(self) -> bool:
+ """`!True` if the connection is closed."""
+ return self.pgconn.status == BAD
+
+ @property
+ def broken(self) -> bool:
+ """
+ `!True` if the connection was interrupted.
+
+ A broken connection is always `closed`, but wasn't closed in a clean
+ way, such as using `close()` or a `!with` block.
+ """
+ return self.pgconn.status == BAD and not self._closed
+
+ @property
+ def autocommit(self) -> bool:
+ """The autocommit state of the connection."""
+ return self._autocommit
+
+ @autocommit.setter
+ def autocommit(self, value: bool) -> None:
+ self._set_autocommit(value)
+
+ def _set_autocommit(self, value: bool) -> None:
+ raise NotImplementedError
+
+ def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
+ yield from self._check_intrans_gen("autocommit")
+ self._autocommit = bool(value)
+
+ @property
+ def isolation_level(self) -> Optional[IsolationLevel]:
+ """
+ The isolation level of the new transactions started on the connection.
+ """
+ return self._isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._set_isolation_level(value)
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ raise NotImplementedError
+
+ def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
+ yield from self._check_intrans_gen("isolation_level")
+ self._isolation_level = IsolationLevel(value) if value is not None else None
+ self._begin_statement = b""
+
+ @property
+ def read_only(self) -> Optional[bool]:
+ """
+ The read-only state of the new transactions started on the connection.
+ """
+ return self._read_only
+
+ @read_only.setter
+ def read_only(self, value: Optional[bool]) -> None:
+ self._set_read_only(value)
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("read_only")
+ self._read_only = bool(value)
+ self._begin_statement = b""
+
+ @property
+ def deferrable(self) -> Optional[bool]:
+ """
+ The deferrable state of the new transactions started on the connection.
+ """
+ return self._deferrable
+
+ @deferrable.setter
+ def deferrable(self, value: Optional[bool]) -> None:
+ self._set_deferrable(value)
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ raise NotImplementedError
+
+ def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
+ yield from self._check_intrans_gen("deferrable")
+ self._deferrable = bool(value)
+ self._begin_statement = b""
+
+ def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
+ # Raise an exception if we are in a transaction
+ status = self.pgconn.transaction_status
+ if status == IDLE and self._pipeline:
+ yield from self._pipeline._sync_gen()
+ status = self.pgconn.transaction_status
+ if status != IDLE:
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection.transaction() context in progress"
+ )
+ else:
+ raise e.ProgrammingError(
+ f"can't change {attribute!r} now: "
+ "connection in transaction status "
+ f"{pq.TransactionStatus(status).name}"
+ )
+
+ @property
+ def info(self) -> ConnectionInfo:
+ """A `ConnectionInfo` attribute to inspect connection properties."""
+ return ConnectionInfo(self.pgconn)
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ self._adapters = AdaptersMap(postgres.adapters)
+
+ return self._adapters
+
+ @property
+ def connection(self) -> "BaseConnection[Row]":
+ # implement the AdaptContext protocol
+ return self
+
+ def fileno(self) -> int:
+ """Return the file descriptor of the connection.
+
+ This function allows using the connection as a file-like object in
+ functions waiting for readiness, such as the ones defined in the
+ `selectors` module.
+ """
+ return self.pgconn.socket
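+
+ # Illustrative sketch: because of fileno(), the connection can be
+ # registered with the selectors module (assuming `conn` is an open
+ # connection):
+ #
+ #     import selectors
+ #     sel = selectors.DefaultSelector()
+ #     sel.register(conn, selectors.EVENT_READ)   # wait for readable data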
+
+ def cancel(self) -> None:
+ """Cancel the current operation on the connection."""
+ # No-op if the connection is closed:
+ # this allows using the method as a callback handler without worrying
+ # about the connection's lifetime.
+ if self.closed:
+ return
+
+ if self._tpc and self._tpc[1]:
+ raise e.ProgrammingError(
+ "cancel() cannot be used with a prepared two-phase transaction"
+ )
+
+ c = self.pgconn.get_cancel()
+ c.cancel()
+
+ def add_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Register a callable to be invoked when a notice message is received.
+
+ :param callback: the callback to call when a message is received.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.append(callback)
+
+ def remove_notice_handler(self, callback: NoticeHandler) -> None:
+ """
+ Unregister a notice message callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+ """
+ self._notice_handlers.remove(callback)
+
+ @staticmethod
+ def _notice_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
+ ) -> None:
+ self = wself()
+ if not (self and self._notice_handlers):
+ return
+
+ diag = e.Diagnostic(res, pgconn_encoding(self.pgconn))
+ for cb in self._notice_handlers:
+ try:
+ cb(diag)
+ except Exception as ex:
+ logger.exception("error processing notice callback '%s': %s", cb, ex)
+
+ def add_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Register a callable to be invoked whenever a notification is received.
+
+ :param callback: the callback to call when a notification is received.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.append(callback)
+
+ def remove_notify_handler(self, callback: NotifyHandler) -> None:
+ """
+ Unregister a notification callable previously registered.
+
+ :param callback: the callback to remove.
+ :type callback: Callable[[~psycopg.Notify], None]
+ """
+ self._notify_handlers.remove(callback)
+
+ @staticmethod
+ def _notify_handler(
+ wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+ ) -> None:
+ self = wself()
+ if not (self and self._notify_handlers):
+ return
+
+ enc = pgconn_encoding(self.pgconn)
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ for cb in self._notify_handlers:
+ cb(n)
+
+ @property
+ def prepare_threshold(self) -> Optional[int]:
+ """
+ Number of times a query is executed before it is prepared.
+
+ - If it is set to 0, every query is prepared the first time it is
+ executed.
+ - If it is set to `!None`, prepared statements are disabled on the
+ connection.
+
+ Default value: 5
+ """
+ return self._prepared.prepare_threshold
+
+ @prepare_threshold.setter
+ def prepare_threshold(self, value: Optional[int]) -> None:
+ self._prepared.prepare_threshold = value
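+
+ # Illustrative sketch (assuming `conn` is an open connection):
+ #
+ #     conn.prepare_threshold = 0      # prepare every query at first execution
+ #     conn.prepare_threshold = None   # disable prepared statements entirely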
+
+ @property
+ def prepared_max(self) -> int:
+ """
+ Maximum number of prepared statements on the connection.
+
+ Default value: 100
+ """
+ return self._prepared.prepared_max
+
+ @prepared_max.setter
+ def prepared_max(self, value: int) -> None:
+ self._prepared.prepared_max = value
+
+ # Generators to perform high-level operations on the connection
+ #
+ # These operations are expressed in terms of non-blocking generators
+ # and the task of waiting when needed (when the generators yield) is left
+ # to the connections subclass, which might wait either in blocking mode
+ # or through asyncio.
+ #
+ # All these generators assume exclusive access to the connection: subclasses
+ # should have a lock and hold it before calling and consuming them.
+
+ @classmethod
+ def _connect_gen(
+ cls: Type[ConnectionType],
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ ) -> PQGenConn[ConnectionType]:
+ """Generator to connect to the database and create a new instance."""
+ pgconn = yield from connect(conninfo)
+ conn = cls(pgconn)
+ conn._autocommit = bool(autocommit)
+ return conn
+
+ def _exec_command(
+ self, command: Query, result_format: pq.Format = TEXT
+ ) -> PQGen[Optional["PGresult"]]:
+ """
+ Generator to send a command and receive the result to the backend.
+
+ Only used to implement internal commands such as "commit", with any
+ arguments bound client-side. The cursor can do more complex stuff.
+ """
+ self._check_connection_ok()
+
+ if isinstance(command, str):
+ command = command.encode(pgconn_encoding(self.pgconn))
+ elif isinstance(command, Composable):
+ command = command.as_bytes(self)
+
+ if self._pipeline:
+ cmd = partial(
+ self.pgconn.send_query_params,
+ command,
+ None,
+ result_format=result_format,
+ )
+ self._pipeline.command_queue.append(cmd)
+ self._pipeline.result_queue.append(None)
+ return None
+
+ self.pgconn.send_query_params(command, None, result_format=result_format)
+
+ result = (yield from execute(self.pgconn))[-1]
+ if result.status != COMMAND_OK and result.status != TUPLES_OK:
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+ else:
+ raise e.InterfaceError(
+ f"unexpected result {pq.ExecStatus(result.status).name}"
+ f" from command {command.decode()!r}"
+ )
+ return result
+
+ def _check_connection_ok(self) -> None:
+ if self.pgconn.status == OK:
+ return
+
+ if self.pgconn.status == BAD:
+ raise e.OperationalError("the connection is closed")
+ raise e.InterfaceError(
+ "cannot execute operations: the connection is"
+ f" in status {self.pgconn.status}"
+ )
+
+ def _start_query(self) -> PQGen[None]:
+ """Generator to start a transaction if necessary."""
+ if self._autocommit:
+ return
+
+ if self.pgconn.transaction_status != IDLE:
+ return
+
+ yield from self._exec_command(self._get_tx_start_command())
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _get_tx_start_command(self) -> bytes:
+ if self._begin_statement:
+ return self._begin_statement
+
+ parts = [b"BEGIN"]
+
+ if self.isolation_level is not None:
+ val = IsolationLevel(self.isolation_level)
+ parts.append(b"ISOLATION LEVEL")
+ parts.append(val.name.replace("_", " ").encode())
+
+ if self.read_only is not None:
+ parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
+
+ if self.deferrable is not None:
+ parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
+
+ self._begin_statement = b" ".join(parts)
+ return self._begin_statement
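+
+ # Illustrative sketch: with isolation_level = IsolationLevel.SERIALIZABLE
+ # and read_only = True, the statement built above would be
+ # b"BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY".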
+
+ def _commit_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.commit()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit commit() forbidden within a Transaction "
+ "context. (Transaction will be automatically committed "
+ "on successful exit from context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "commit() cannot be used during a two-phase transaction"
+ )
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"COMMIT")
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _rollback_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.rollback()`."""
+ if self._num_transactions:
+ raise e.ProgrammingError(
+ "Explicit rollback() forbidden within a Transaction "
+ "context. (Either raise Rollback() or allow "
+ "an exception to propagate out of the context.)"
+ )
+ if self._tpc:
+ raise e.ProgrammingError(
+ "rollback() cannot be used during a two-phase transaction"
+ )
+
+ # Get out of a "pipeline aborted" state
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ if self.pgconn.transaction_status == IDLE:
+ return
+
+ yield from self._exec_command(b"ROLLBACK")
+ self._prepared.clear()
+ for cmd in self._prepared.get_maintenance_commands():
+ yield from self._exec_command(cmd)
+
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+ """
+ Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
+
+ The argument types and constraints are explained in
+ :ref:`two-phase-commit`.
+
+ The values passed to the method will be available on the returned
+ object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
+ """
+ self._check_tpc()
+ return Xid.from_parts(format_id, gtrid, bqual)
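+
+ # Illustrative usage sketch (identifiers are placeholders):
+ #
+ #     xid = conn.xid(1, "gtrid-example", "bqual-example")
+ #     conn.tpc_begin(xid)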
+
+ def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+ self._check_tpc()
+
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self.pgconn.transaction_status != IDLE:
+ raise e.ProgrammingError(
+ "can't start two-phase transaction: connection in status"
+ f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
+ )
+
+ if self._autocommit:
+ raise e.ProgrammingError(
+ "can't use two-phase transactions in autocommit mode"
+ )
+
+ self._tpc = (xid, False)
+ yield from self._exec_command(self._get_tx_start_command())
+
+ def _tpc_prepare_gen(self) -> PQGen[None]:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' must be called inside a two-phase transaction"
+ )
+ if self._tpc[1]:
+ raise e.ProgrammingError(
+ "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
+ )
+ xid = self._tpc[0]
+ self._tpc = (xid, True)
+ yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
+ if self._pipeline:
+ yield from self._pipeline._sync_gen()
+
+ def _tpc_finish_gen(
+ self, action: LiteralString, xid: Union[Xid, str, None]
+ ) -> PQGen[None]:
+ fname = f"tpc_{action.lower()}()"
+ if xid is None:
+ if not self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} without xid must must be"
+ " called inside a two-phase transaction"
+ )
+ xid = self._tpc[0]
+ else:
+ if self._tpc:
+ raise e.ProgrammingError(
+ f"{fname} with xid must must be called"
+ " outside a two-phase transaction"
+ )
+ if not isinstance(xid, Xid):
+ xid = Xid.from_string(xid)
+
+ if self._tpc and not self._tpc[1]:
+ meth: Callable[[], PQGen[None]]
+ meth = getattr(self, f"_{action.lower()}_gen")
+ self._tpc = None
+ yield from meth()
+ else:
+ yield from self._exec_command(
+ SQL("{} PREPARED {}").format(SQL(action), str(xid))
+ )
+ self._tpc = None
+
+ def _check_tpc(self) -> None:
+ """Raise NotSupportedError if TPC is not supported."""
+ # TPC supported on every supported PostgreSQL version.
+ pass
+
+
+class Connection(BaseConnection[Row]):
+ """
+ Wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[Cursor[Row]]
+ server_cursor_factory: Type[ServerCursor[Row]]
+ row_factory: RowFactory[Row]
+ _pipeline: Optional[Pipeline]
+ _Self = TypeVar("_Self", bound="Connection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = threading.Lock()
+ self.cursor_factory = Cursor
+ self.server_cursor_factory = ServerCursor
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[Cursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "Connection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: Optional[RowFactory[Row]] = None,
+ cursor_factory: Optional[Type[Cursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Any,
+ ) -> "Connection[Any]":
+ """
+ Connect to a database server and return a new `Connection` instance.
+ """
+ params = cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
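+
+ # Illustrative usage sketch (the conninfo string is a placeholder):
+ #
+ #     with Connection.connect("dbname=test user=app") as conn:
+ #         conn.execute("SELECT 1")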
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ self.close()
+
+ @classmethod
+ def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ :param conninfo: Connection string as received by `~Connection.connect()`.
+ :param kwargs: Overriding connection arguments as received by `!connect()`.
+ :return: Connection arguments merged and possibly modified, in a
+ format similar to `~conninfo.conninfo_to_dict()`.
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is a usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ return params
+
+ def close(self) -> None:
+ """Close the database connection."""
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> Cursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+ ) -> Cursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: RowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> ServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[RowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[Cursor[Any], ServerCursor[Any]]:
+ """
+ Return a new cursor to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[Cursor[Any], ServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> Cursor[Row]:
+ """Execute a query and return a cursor to read its results."""
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ def commit(self) -> None:
+ """Commit any pending transaction to the database."""
+ with self.lock:
+ self.wait(self._commit_gen())
+
+ def rollback(self) -> None:
+ """Roll back to the start of any pending transaction."""
+ with self.lock:
+ self.wait(self._rollback_gen())
+
+ @contextmanager
+ def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> Iterator[Transaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :param savepoint_name: Name of the savepoint used to manage a nested
+ transaction. If `!None`, one will be chosen automatically.
+ :param force_rollback: Roll back the transaction at the end of the
+ block even if there was no error (e.g. to try a procedure without
+ committing its effects).
+ :rtype: Transaction
+ """
+ tx = Transaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ with tx:
+ yield tx
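+
+ # Illustrative usage sketch (assuming a table `tbl` exists):
+ #
+ #     with conn.transaction():
+ #         conn.execute("INSERT INTO tbl VALUES (%s)", (10,))
+ #     # committed on successful exit; rolled back if an exception escaped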
+
+ def notifies(self) -> Generator[Notify, None, None]:
+ """
+ Yield `Notify` objects as soon as they are received from the database.
+ """
+ while True:
+ with self.lock:
+ try:
+ ns = self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
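+
+ # Illustrative usage sketch (the channel name is a placeholder; the LISTEN
+ # must be effective, e.g. the connection is in autocommit mode):
+ #
+ #     conn.execute("LISTEN mychan")
+ #     for n in conn.notifies():          # blocks, looping until interrupted
+ #         print(n.channel, n.payload, n.pid)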
+
+ @contextmanager
+ def pipeline(self) -> Iterator[Pipeline]:
+ """Switch the connection into pipeline mode."""
+ with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = Pipeline(self)
+
+ try:
+ with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
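+
+ # Illustrative usage sketch (assuming a table `tbl` exists):
+ #
+ #     with conn.pipeline():
+ #         conn.execute("INSERT INTO tbl VALUES (%s)", (1,))
+ #         conn.execute("INSERT INTO tbl VALUES (%s)", (2,))
+ #     # statements are sent in a batch; results are fetched at sync points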
+
+ def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ """
+ Consume a generator operating on the connection.
+
+ The function must be used on generators that don't change the connection
+ fd (i.e. not on connect and reset).
+ """
+ try:
+ return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except KeyboardInterrupt:
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will remain stuck in ACTIVE state.
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ """Consume a connection generator."""
+ return waiting.wait_conn(gen, timeout=timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ with self.lock:
+ self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ with self.lock:
+ self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ with self.lock:
+ self.wait(self._set_deferrable_gen(value))
+
+ def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ """
+ Begin a TPC transaction with the given transaction ID `!xid`.
+ """
+ with self.lock:
+ self.wait(self._tpc_begin_gen(xid))
+
+ def tpc_prepare(self) -> None:
+ """
+ Perform the first phase of a transaction started with `tpc_begin()`.
+ """
+ try:
+ with self.lock:
+ self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Commit a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("COMMIT", xid))
+
+ def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ """
+ Roll back a prepared two-phase transaction.
+ """
+ with self.lock:
+ self.wait(self._tpc_finish_gen("ROLLBACK", xid))
+
+ def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ cur.execute(Xid._get_recover_query())
+ res = cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
new file mode 100644
index 0000000..aa02dc0
--- /dev/null
+++ b/psycopg/psycopg/connection_async.py
@@ -0,0 +1,436 @@
+"""
+psycopg async connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import asyncio
+import logging
+from types import TracebackType
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from . import waiting
+from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from ._tpc import Xid
+from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from ._pipeline import AsyncPipeline
+from ._encodings import pgconn_encoding
+from .connection import BaseConnection, CursorRow, Notify
+from .generators import notifies
+from .transaction import AsyncTransaction
+from .cursor_async import AsyncCursor
+from .server_cursor import AsyncServerCursor
+
+if TYPE_CHECKING:
+ from .pq.abc import PGconn
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class AsyncConnection(BaseConnection[Row]):
+ """
+ Asynchronous wrapper for a connection to the database.
+ """
+
+ __module__ = "psycopg"
+
+ cursor_factory: Type[AsyncCursor[Row]]
+ server_cursor_factory: Type[AsyncServerCursor[Row]]
+ row_factory: AsyncRowFactory[Row]
+ _pipeline: Optional[AsyncPipeline]
+ _Self = TypeVar("_Self", bound="AsyncConnection[Any]")
+
+ def __init__(
+ self,
+ pgconn: "PGconn",
+ row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
+ ):
+ super().__init__(pgconn)
+ self.row_factory = row_factory
+ self.lock = asyncio.Lock()
+ self.cursor_factory = AsyncCursor
+ self.server_cursor_factory = AsyncServerCursor
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[Row]":
+ # TODO: returned type should be _Self. See #308.
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncConnection[TupleRow]":
+ ...
+
+ @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ context: Optional[AdaptContext] = None,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
+ **kwargs: Any,
+ ) -> "AsyncConnection[Any]":
+
+ if sys.platform == "win32":
+ loop = asyncio.get_running_loop()
+ if isinstance(loop, asyncio.ProactorEventLoop):
+ raise e.InterfaceError(
+ "Psycopg cannot use the 'ProactorEventLoop' to run in async"
+ " mode. Please use a compatible event loop, for instance by"
+ " setting 'asyncio.set_event_loop_policy"
+ "(WindowsSelectorEventLoopPolicy())'"
+ )
+
+ params = await cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
+ try:
+ rv = await cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit),
+ timeout=params["connect_timeout"],
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ if row_factory:
+ rv.row_factory = row_factory
+ if cursor_factory:
+ rv.cursor_factory = cursor_factory
+ if context:
+ rv._adapters = AdaptersMap(context.adapters)
+ rv.prepare_threshold = prepare_threshold
+ return rv
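+
+ # Illustrative usage sketch (inside an asyncio coroutine; the conninfo
+ # string is a placeholder):
+ #
+ #     aconn = await AsyncConnection.connect("dbname=test")
+ #     async with aconn:
+ #         await aconn.execute("SELECT 1")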
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ if self.closed:
+ return
+
+ if exc_type:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ await self.rollback()
+ except Exception as exc2:
+ logger.warning(
+ "error ignored in rollback on %s: %s",
+ self,
+ exc2,
+ )
+ else:
+ await self.commit()
+
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ await self.close()
+
+ @classmethod
+ async def _get_connection_params(
+ cls, conninfo: str, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Manipulate connection parameters before connecting.
+
+ .. versionchanged:: 3.1
+ Unlike the sync counterpart, perform non-blocking address
+ resolution and populate the ``hostaddr`` connection parameter,
+ unless the user has provided one themselves. See
+ `~psycopg._dns.resolve_hostaddr_async()` for details.
+
+ """
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is a usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ # Resolve host addresses in non-blocking way
+ params = await resolve_hostaddr_async(params)
+
+ return params
+
+ async def close(self) -> None:
+ if self.closed:
+ return
+ self._closed = True
+ self.pgconn.finish()
+
+ @overload
+ def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
+ ) -> AsyncCursor[CursorRow]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[Row]:
+ ...
+
+ @overload
+ def cursor(
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: AsyncRowFactory[CursorRow],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> AsyncServerCursor[CursorRow]:
+ ...
+
+ def cursor(
+ self,
+ name: str = "",
+ *,
+ binary: bool = False,
+ row_factory: Optional[AsyncRowFactory[Any]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
+ """
+ Return a new `AsyncCursor` to send commands and queries to the connection.
+ """
+ self._check_connection_ok()
+
+ if not row_factory:
+ row_factory = self.row_factory
+
+ cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
+ if name:
+ cur = self.server_cursor_factory(
+ self,
+ name=name,
+ row_factory=row_factory,
+ scrollable=scrollable,
+ withhold=withhold,
+ )
+ else:
+ cur = self.cursor_factory(self, row_factory=row_factory)
+
+ if binary:
+ cur.format = BINARY
+
+ return cur
+
+ async def execute(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: bool = False,
+ ) -> AsyncCursor[Row]:
+ try:
+ cur = self.cursor()
+ if binary:
+ cur.format = BINARY
+
+ return await cur.execute(query, params, prepare=prepare)
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def commit(self) -> None:
+ async with self.lock:
+ await self.wait(self._commit_gen())
+
+ async def rollback(self) -> None:
+ async with self.lock:
+ await self.wait(self._rollback_gen())
+
+ @asynccontextmanager
+ async def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> AsyncIterator[AsyncTransaction]:
+ """
+ Start a context block with a new transaction or nested transaction.
+
+ :rtype: AsyncTransaction
+ """
+ tx = AsyncTransaction(self, savepoint_name, force_rollback)
+ if self._pipeline:
+ async with self.pipeline(), tx, self.pipeline():
+ yield tx
+ else:
+ async with tx:
+ yield tx
+
+ async def notifies(self) -> AsyncGenerator[Notify, None]:
+ while True:
+ async with self.lock:
+ try:
+ ns = await self.wait(notifies(self.pgconn))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ enc = pgconn_encoding(self.pgconn)
+ for pgn in ns:
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+ yield n
+
+ @asynccontextmanager
+ async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
+ """Context manager to switch the connection into pipeline mode."""
+ async with self.lock:
+ self._check_connection_ok()
+
+ pipeline = self._pipeline
+ if pipeline is None:
+ # WARNING: reference loop, broken ahead.
+ pipeline = self._pipeline = AsyncPipeline(self)
+
+ try:
+ async with pipeline:
+ yield pipeline
+ finally:
+ if pipeline.level == 0:
+ async with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
+
+ async def wait(self, gen: PQGen[RV]) -> RV:
+ try:
+ return await waiting.wait_async(gen, self.pgconn.socket)
+ except KeyboardInterrupt:
+ # TODO: this doesn't seem to work as it does for sync connections
+ # see tests/test_concurrency_async.py::test_ctrl_c
+ # In the test, the code doesn't reach this branch.
+
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will be stuck in ACTIVE state.
+ c = self.pgconn.get_cancel()
+ c.cancel()
+ try:
+ await waiting.wait_async(gen, self.pgconn.socket)
+ except e.QueryCanceled:
+ pass # as expected
+ raise
+
+ @classmethod
+ async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
+ return await waiting.wait_conn_async(gen, timeout)
+
+ def _set_autocommit(self, value: bool) -> None:
+ self._no_set_async("autocommit")
+
+ async def set_autocommit(self, value: bool) -> None:
+ """Async version of the `~Connection.autocommit` setter."""
+ async with self.lock:
+ await self.wait(self._set_autocommit_gen(value))
+
+ def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ self._no_set_async("isolation_level")
+
+ async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+ """Async version of the `~Connection.isolation_level` setter."""
+ async with self.lock:
+ await self.wait(self._set_isolation_level_gen(value))
+
+ def _set_read_only(self, value: Optional[bool]) -> None:
+ self._no_set_async("read_only")
+
+ async def set_read_only(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.read_only` setter."""
+ async with self.lock:
+ await self.wait(self._set_read_only_gen(value))
+
+ def _set_deferrable(self, value: Optional[bool]) -> None:
+ self._no_set_async("deferrable")
+
+ async def set_deferrable(self, value: Optional[bool]) -> None:
+ """Async version of the `~Connection.deferrable` setter."""
+ async with self.lock:
+ await self.wait(self._set_deferrable_gen(value))
+
+ def _no_set_async(self, attribute: str) -> None:
+ raise AttributeError(
+ f"'the {attribute!r} property is read-only on async connections:"
+ f" please use 'await .set_{attribute}()' instead."
+ )
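+
+ # Illustrative sketch: on an async connection use, for instance,
+ #
+ #     await aconn.set_autocommit(True)
+ #
+ # instead of `aconn.autocommit = True`, which raises AttributeError.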
+
+ async def tpc_begin(self, xid: Union[Xid, str]) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_begin_gen(xid))
+
+ async def tpc_prepare(self) -> None:
+ try:
+ async with self.lock:
+ await self.wait(self._tpc_prepare_gen())
+ except e.ObjectNotInPrerequisiteState as ex:
+ raise e.NotSupportedError(str(ex)) from None
+
+ async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("commit", xid))
+
+ async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
+ async with self.lock:
+ await self.wait(self._tpc_finish_gen("rollback", xid))
+
+ async def tpc_recover(self) -> List[Xid]:
+ self._check_tpc()
+ status = self.info.transaction_status
+ async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
+ await cur.execute(Xid._get_recover_query())
+ res = await cur.fetchall()
+
+ if status == IDLE and self.info.transaction_status == INTRANS:
+ await self.rollback()
+
+ return res
diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py
new file mode 100644
index 0000000..3b21f83
--- /dev/null
+++ b/psycopg/psycopg/conninfo.py
@@ -0,0 +1,378 @@
+"""
+Functions to manipulate conninfo strings
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import socket
+import asyncio
+from typing import Any, Dict, List, Optional
+from pathlib import Path
+from datetime import tzinfo
+from functools import lru_cache
+from ipaddress import ip_address
+
+from . import pq
+from . import errors as e
+from ._tz import get_tzinfo
+from ._encodings import pgconn_encoding
+
+
+def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+ """
+ Merge a string and keyword params into a single conninfo string.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: A connection string valid for PostgreSQL, with the `!kwargs`
+ parameters merged.
+
+ Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
+ conninfo string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ if not conninfo and not kwargs:
+ return ""
+
+ # If no kwargs are specified don't munge the conninfo but check that it's correct.
+ # Make sure to return a string, not a subtype, to avoid making Liskov sad.
+ if not kwargs:
+ _parse_conninfo(conninfo)
+ return str(conninfo)
+
+ # Override the conninfo with the parameters
+ # Drop the None arguments
+ kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
+
+ if conninfo:
+ tmp = conninfo_to_dict(conninfo)
+ tmp.update(kwargs)
+ kwargs = tmp
+
+ conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
+
+ # Verify the result is valid
+ _parse_conninfo(conninfo)
+
+ return conninfo
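+
+# Illustrative sketch (values are placeholders):
+#
+#     make_conninfo("dbname=test user=app", port=5433)
+#     # -> 'dbname=test user=app port=5433'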
+
+
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+ """
+ Convert the `!conninfo` string into a dictionary of parameters.
+
+ :param conninfo: A `connection string`__ as accepted by PostgreSQL.
+ :param kwargs: Parameters overriding the ones specified in `!conninfo`.
+ :return: Dictionary with the parameters parsed from `!conninfo` and
+ `!kwargs`.
+
+ Raise `~psycopg.ProgrammingError` if `!conninfo` is not a valid connection
+ string.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-CONNSTRING
+ """
+ opts = _parse_conninfo(conninfo)
+ rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+ for k, v in kwargs.items():
+ if v is not None:
+ rv[k] = v
+ return rv
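+
+# Illustrative sketch:
+#
+#     conninfo_to_dict("dbname=test connect_timeout=10")
+#     # -> {'dbname': 'test', 'connect_timeout': '10'}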
+
+
+def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+ """
+ Verify that `!conninfo` is a valid connection string.
+
+ Raise ProgrammingError if the string is not valid.
+
+ Return the result of pq.Conninfo.parse() on success.
+ """
+ try:
+ return pq.Conninfo.parse(conninfo.encode())
+ except e.OperationalError as ex:
+ raise e.ProgrammingError(str(ex))
+
+
+re_escape = re.compile(r"([\\'])")
+re_space = re.compile(r"\s")
+
+
+def _param_escape(s: str) -> str:
+ """
+ Apply the escaping rule required by PQconnectdb
+ """
+ if not s:
+ return "''"
+
+ s = re_escape.sub(r"\\\1", s)
+ if re_space.search(s):
+ s = "'" + s + "'"
+
+ return s
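+
+# Illustrative sketch of the escaping rule:
+#
+#     _param_escape("hello world")   # -> "'hello world'"
+#     _param_escape("it's")          # -> "it\'s" (quote escaped with a backslash)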
+
+
+class ConnectionInfo:
+ """Allow access to information about the connection."""
+
+ __module__ = "psycopg"
+
+ def __init__(self, pgconn: pq.abc.PGconn):
+ self.pgconn = pgconn
+
+ @property
+ def vendor(self) -> str:
+ """A string representing the database vendor connected to."""
+ return "PostgreSQL"
+
+ @property
+ def host(self) -> str:
+ """The server host name of the active connection. See :pq:`PQhost()`."""
+ return self._get_pgconn_attr("host")
+
+ @property
+ def hostaddr(self) -> str:
+ """The server IP address of the connection. See :pq:`PQhostaddr()`."""
+ return self._get_pgconn_attr("hostaddr")
+
+ @property
+ def port(self) -> int:
+ """The port of the active connection. See :pq:`PQport()`."""
+ return int(self._get_pgconn_attr("port"))
+
+ @property
+ def dbname(self) -> str:
+ """The database name of the connection. See :pq:`PQdb()`."""
+ return self._get_pgconn_attr("db")
+
+ @property
+ def user(self) -> str:
+ """The user name of the connection. See :pq:`PQuser()`."""
+ return self._get_pgconn_attr("user")
+
+ @property
+ def password(self) -> str:
+ """The password of the connection. See :pq:`PQpass()`."""
+ return self._get_pgconn_attr("password")
+
+ @property
+ def options(self) -> str:
+ """
+ The command-line options passed in the connection request.
+ See :pq:`PQoptions`.
+ """
+ return self._get_pgconn_attr("options")
+
+ def get_parameters(self) -> Dict[str, str]:
+ """Return the connection parameters values.
+
+ Return all the parameters set to a non-default value, which might come
+ either from the connection string and parameters passed to
+ `~Connection.connect()` or from environment variables. The password
+ is never returned (you can read it using the `password` attribute).
+ """
+ pyenc = self.encoding
+
+ # Get the known defaults to avoid reporting them
+ defaults = {
+ i.keyword: i.compiled
+ for i in pq.Conninfo.get_defaults()
+ if i.compiled is not None
+ }
+ # Not returned by the libpq. Bug? Bet we're using SSH.
+ defaults.setdefault(b"channel_binding", b"prefer")
+ defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
+
+ return {
+ i.keyword.decode(pyenc): i.val.decode(pyenc)
+ for i in self.pgconn.info
+ if i.val is not None
+ and i.keyword != b"password"
+ and i.val != defaults.get(i.keyword)
+ }
+
+ @property
+ def dsn(self) -> str:
+ """Return the connection string to connect to the database.
+
+ The string contains all the parameters set to a non-default value,
+ which might come either from the connection string and parameters
+ passed to `~Connection.connect()` or from environment variables. The
+ password is never returned (you can read it using the `password`
+ attribute).
+ """
+ return make_conninfo(**self.get_parameters())
+
+ @property
+ def status(self) -> pq.ConnStatus:
+ """The status of the connection. See :pq:`PQstatus()`."""
+ return pq.ConnStatus(self.pgconn.status)
+
+ @property
+ def transaction_status(self) -> pq.TransactionStatus:
+ """
+ The current in-transaction status of the session.
+ See :pq:`PQtransactionStatus()`.
+ """
+ return pq.TransactionStatus(self.pgconn.transaction_status)
+
+ @property
+ def pipeline_status(self) -> pq.PipelineStatus:
+ """
+ The current pipeline status of the client.
+ See :pq:`PQpipelineStatus()`.
+ """
+ return pq.PipelineStatus(self.pgconn.pipeline_status)
+
+ def parameter_status(self, param_name: str) -> Optional[str]:
+ """
+ Return a parameter setting of the connection.
+
+ Return `None` if the parameter is unknown.
+ """
+ res = self.pgconn.parameter_status(param_name.encode(self.encoding))
+ return res.decode(self.encoding) if res is not None else None
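+
+ # Illustrative sketch (the returned value depends on the server):
+ #
+ #     conn.info.parameter_status("server_version")   # e.g. "15.3"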
+
+ @property
+ def server_version(self) -> int:
+ """
+ An integer representing the server version. See :pq:`PQserverVersion()`.
+ """
+ return self.pgconn.server_version
+
+ @property
+ def backend_pid(self) -> int:
+ """
+ The process ID (PID) of the backend process handling this connection.
+ See :pq:`PQbackendPID()`.
+ """
+ return self.pgconn.backend_pid
+
+ @property
+ def error_message(self) -> str:
+ """
+ The error message most recently generated by an operation on the connection.
+ See :pq:`PQerrorMessage()`.
+ """
+ return self._get_pgconn_attr("error_message")
+
+ @property
+ def timezone(self) -> tzinfo:
+ """The Python timezone info of the connection's timezone."""
+ return get_tzinfo(self.pgconn)
+
+ @property
+ def encoding(self) -> str:
+ """The Python codec name of the connection's client encoding."""
+ return pgconn_encoding(self.pgconn)
+
+ def _get_pgconn_attr(self, name: str) -> str:
+ value: bytes = getattr(self.pgconn, name)
+ return value.decode(self.encoding)
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+
+ If a ``host`` param is present but not ``hostaddr``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostaddr``, ``port`` to allow
+ connecting without further DNS lookups, possibly removing hosts that are
+ not resolved, keeping the lists of hosts and ports consistent.
+
+ Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
+ host could be resolved, inconsistent list lengths).
+ """
+ hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+ if hostaddr_arg:
+ # Already resolved
+ return params
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ if not host_arg:
+ # Nothing to resolve
+ return params
+
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",") if port_arg else []
+ default_port = "5432"
+
+ if len(ports_in) == 1:
+ # If only one port is specified, the libpq will apply it to all
+ # the hosts, so don't mangle it.
+ default_port = ports_in.pop()
+
+ elif len(ports_in) > 1:
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what is raised if the libpq fails to connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
+ )
+ ports_out = []
+
+ hosts_out = []
+ hostaddr_out = []
+ loop = asyncio.get_running_loop()
+ for i, host in enumerate(hosts_in):
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path
+ hosts_out.append(host)
+ hostaddr_out.append("")
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ # If the host is already an ip address don't try to resolve it
+ if is_ip_address(host):
+ hosts_out.append(host)
+ hostaddr_out.append(host)
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ try:
+ port = ports_in[i] if ports_in else default_port
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ except OSError as ex:
+ last_exc = ex
+ else:
+ for item in ans:
+ hosts_out.append(host)
+ hostaddr_out.append(item[4][0])
+ if ports_in:
+ ports_out.append(ports_in[i])
+
+ # Throw an exception if no host could be resolved
+ if not hosts_out:
+ raise e.OperationalError(str(last_exc))
+
+ out = params.copy()
+ out["host"] = ",".join(hosts_out)
+ out["hostaddr"] = ",".join(hostaddr_out)
+ if ports_in:
+ out["port"] = ",".join(ports_out)
+
+ return out
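+
+# Illustrative sketch (inside an asyncio coroutine; the host name is a
+# placeholder):
+#
+#     params = conninfo_to_dict("host=db.example.com dbname=test")
+#     params = await resolve_hostaddr_async(params)
+#     # params["hostaddr"] now contains the resolved IP address(es)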
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py
new file mode 100644
index 0000000..7514306
--- /dev/null
+++ b/psycopg/psycopg/copy.py
@@ -0,0 +1,904 @@
+"""
+psycopg copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import queue
+import struct
+import asyncio
+import threading
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
+from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import Buffer, ConnectionType, PQGen, Transformer
+from ._compat import create_task
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding
+from .generators import copy_from, copy_to, copy_end
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from .connection import Connection # noqa: F401
+ from .connection_async import AsyncConnection # noqa: F401
+
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+# Size of data to accumulate before sending it down the network. We fill a
+ # buffer of this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+ # (#255) so we want to split it into more digestible chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+ # makes operations *way* slower! Probably triggering some quadratic behaviour
+# in the libpq memory management and data sending.
+
+ # Max size of the write queue of buffers. If more buffers than this are
+ # queued, the copy will block. Each buffer should be around BUFFER_SIZE.
+QUEUE_SIZE = 1024
+
+
+class BaseCopy(Generic[ConnectionType]):
+ """
+ Base implementation for the copy user interface.
+
+ Two subclasses expose real methods with the sync/async differences.
+
+ The difference between the text and binary format is managed by two
+ different `Formatter` subclasses.
+
+ Writing (the I/O part) is implemented in the subclasses by a `Writer` or
+ `AsyncWriter` instance. Normally writing implies sending copy data to a
+ database, but a different writer might be chosen, e.g. to stream data into
+ a file for later use.
+ """
+
+ _Self = TypeVar("_Self", bound="BaseCopy[Any]")
+
+ formatter: "Formatter"
+
+ def __init__(
+ self,
+ cursor: "BaseCursor[ConnectionType, Any]",
+ *,
+ binary: Optional[bool] = None,
+ ):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ result = cursor.pgresult
+ if result:
+ self._direction = result.status
+ if self._direction != COPY_IN and self._direction != COPY_OUT:
+ raise e.ProgrammingError(
+ "the cursor should have performed a COPY operation;"
+ f" its status is {pq.ExecStatus(self._direction).name} instead"
+ )
+ else:
+ self._direction = COPY_IN
+
+ if binary is None:
+ binary = bool(result and result.binary_tuples)
+
+ tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+ if binary:
+ self.formatter = BinaryFormatter(tx)
+ else:
+ self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+
+ self._finished = False
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def _enter(self) -> None:
+ if self._finished:
+ raise TypeError("copy blocks can be used only once")
+
+ def set_types(self, types: Sequence[Union[int, str]]) -> None:
+ """
+ Set the types expected in a COPY operation.
+
+ The types must be specified as a sequence of oid or PostgreSQL type
+ names (e.g. ``int4``, ``timestamptz[]``).
+
+ This operation overcomes the lack of metadata returned by PostgreSQL
+ when a COPY operation begins:
+
+        - On :sql:`COPY TO`, `!set_types()` allows specifying what types the
+ operation returns. If `!set_types()` is not used, the data will be
+ returned as unparsed strings or bytes instead of Python objects.
+
+        - On :sql:`COPY FROM`, `!set_types()` allows choosing what types the
+ database expects. This is especially useful in binary copy, because
+ PostgreSQL will apply no cast rule.
+
+ """
+ registry = self.cursor.adapters.types
+ oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
+
+ if self._direction == COPY_IN:
+ self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
+ else:
+ self.formatter.transformer.set_loader_types(oids, self.formatter.format)
+
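+    # For example (an illustrative sketch: `cur` is an open cursor and the
+    # query is hypothetical), a binary COPY TO can be made to return Python
+    # values instead of raw data:
+    #
+    #     with cur.copy(
+    #         "COPY (SELECT 1::int4, 'x'::text) TO STDOUT (FORMAT BINARY)"
+    #     ) as copy:
+    #         copy.set_types(["int4", "text"])
+    #         for row in copy.rows():
+    #             ...  # row == (1, 'x')
+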
+ # High level copy protocol generators (state change of the Copy object)
+
+ def _read_gen(self) -> PQGen[Buffer]:
+ if self._finished:
+ return memoryview(b"")
+
+ res = yield from copy_from(self._pgconn)
+ if isinstance(res, memoryview):
+ return res
+
+ # res is the final PGresult
+ self._finished = True
+
+        # This result is a COMMAND_OK which has info about the number of rows
+        # returned, but not about the columns, which was instead received on
+        # the COPY_OUT result at the beginning of COPY.
+ # So, don't replace the results in the cursor, just update the rowcount.
+ nrows = res.command_tuples
+ self.cursor._rowcount = nrows if nrows is not None else -1
+ return memoryview(b"")
+
+ def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+ data = yield from self._read_gen()
+ if not data:
+ return None
+
+ row = self.formatter.parse_row(data)
+ if row is None:
+ # Get the final result to finish the copy operation
+ yield from self._read_gen()
+ self._finished = True
+ return None
+
+ return row
+
+ def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ if not exc:
+ return
+
+ if self._pgconn.transaction_status != ACTIVE:
+            # The server has already finished sending copy data. The connection
+ # is already in a good state.
+ return
+
+ # Throw a cancel to the server, then consume the rest of the copy data
+ # (which might or might not have been already transferred entirely to
+        # the client, so we won't necessarily see the exception associated with
+ # canceling).
+ self.connection.cancel()
+ try:
+ while (yield from self._read_gen()):
+ pass
+ except e.QueryCanceled:
+ pass
+
+
+class Copy(BaseCopy["Connection[Any]"]):
+ """Manage a :sql:`COPY` operation.
+
+ :param cursor: the cursor where the operation is performed.
+ :param binary: if `!True`, write binary format.
+ :param writer: the object to write to destination. If not specified, write
+ to the `!cursor` connection.
+
+ Choosing `!binary` is not necessary if the cursor has executed a
+ :sql:`COPY` operation, because the operation result describes the format
+ too. The parameter is useful when a `!Copy` object is created manually and
+ no operation is performed on the cursor, such as when using ``writer=``\\
+ `~psycopg.copy.FileWriter`.
+
+ """
+
+ __module__ = "psycopg"
+
+ writer: "Writer"
+
+ def __init__(
+ self,
+ cursor: "Cursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["Writer"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+ if not writer:
+ writer = LibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.finish(exc_val)
+
+ # End user sync interface
+
+ def __iter__(self) -> Iterator[Buffer]:
+ """Implement block-by-block iteration on :sql:`COPY TO`."""
+ while True:
+ data = self.read()
+ if not data:
+ break
+ yield data
+
+ def read(self) -> Buffer:
+ """
+ Read an unparsed row after a :sql:`COPY TO` operation.
+
+ Return an empty string when the data is finished.
+ """
+ return self.connection.wait(self._read_gen())
+
+ def rows(self) -> Iterator[Tuple[Any, ...]]:
+ """
+ Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ while True:
+ record = self.read_row()
+ if record is None:
+ break
+ yield record
+
+ def read_row(self) -> Optional[Tuple[Any, ...]]:
+ """
+ Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+ Return `!None` when the data is finished.
+
+ Note that the records returned will be tuples of unparsed strings or
+ bytes, unless data types are specified using `set_types()`.
+ """
+ return self.connection.wait(self._read_row_gen())
+
+ def write(self, buffer: Union[Buffer, str]) -> None:
+ """
+ Write a block of data to a table after a :sql:`COPY FROM` operation.
+
+ If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
+ text mode it can be either `!bytes` or `!str`.
+ """
+ data = self.formatter.write(buffer)
+ if data:
+ self._write(data)
+
+ def write_row(self, row: Sequence[Any]) -> None:
+ """Write a record to a table after a :sql:`COPY FROM` operation."""
+ data = self.formatter.write_row(row)
+ if data:
+ self._write(data)
+
+ def finish(self, exc: Optional[BaseException]) -> None:
+ """Terminate the copy operation and free the resources allocated.
+
+        You shouldn't need to call this function yourself: it is usually
+        called automatically on block exit. It is available if, despite what
+        is documented, you end up using the `Copy` object outside a block.
+ """
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ self._write(data)
+ self.writer.finish(exc)
+ self._finished = True
+ else:
+ self.connection.wait(self._end_copy_out_gen(exc))
+
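+
+# A minimal usage sketch (with an open cursor `cursor`; the `sample` table and
+# its columns are hypothetical), writing records with a COPY FROM:
+#
+#     records = [(10, "hello"), (40, "world")]
+#     with cursor.copy("COPY sample (num, data) FROM STDIN") as copy:
+#         for record in records:
+#             copy.write_row(record)
+#
+# and reading blocks back with a COPY TO:
+#
+#     with cursor.copy("COPY sample TO STDOUT") as copy:
+#         for data in copy:
+#             ...  # data is a chunk of raw COPY data
+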
+
+class Writer(ABC):
+ """
+ A class to write copy data somewhere.
+ """
+
+ @abstractmethod
+ def write(self, data: Buffer) -> None:
+ """
+ Write some data to destination.
+ """
+ ...
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ """
+ Called when write operations are finished.
+
+ If operations finished with an error, it will be passed to ``exc``.
+ """
+ pass
+
+
+class LibpqWriter(Writer):
+ """
+ A `Writer` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class QueuedLibpqWriter(LibpqWriter):
+ """
+ A writer using a buffer to queue data to write to a Postgres database.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "Cursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[threading.Thread] = None
+ self._worker_error: Optional[BaseException] = None
+
+ def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value, or in case
+ of error.
+
+ The function is designed to be run in a separate thread.
+ """
+ try:
+ while True:
+ data = self._queue.get(block=True, timeout=24 * 60 * 60)
+ if not data:
+ break
+ self.connection.wait(copy_to(self._pgconn, data))
+ except BaseException as ex:
+ # Propagate the error to the main thread.
+ self._worker_error = ex
+
+ def write(self, data: Buffer) -> None:
+ if not self._worker:
+            # warning: reference loop, broken by finish()
+ self._worker = threading.Thread(target=self.worker)
+ self._worker.daemon = True
+ self._worker.start()
+
+        # If the worker thread raises an exception, re-raise it to the caller.
+ if self._worker_error:
+ raise self._worker_error
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ def finish(self, exc: Optional[BaseException] = None) -> None:
+ self._queue.put(b"")
+
+ if self._worker:
+ self._worker.join()
+ self._worker = None # break the loop
+
+ # Check if the worker thread raised any exception before terminating.
+ if self._worker_error:
+ raise self._worker_error
+
+ super().finish(exc)
+
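+
+# A sketch of opting into the queued writer from a cursor copy (the table and
+# the records are hypothetical); formatting then happens concurrently with the
+# socket writes:
+#
+#     with cur.copy(
+#         "COPY sample FROM STDIN", writer=QueuedLibpqWriter(cur)
+#     ) as copy:
+#         for record in records:
+#             copy.write_row(record)
+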
+
+class FileWriter(Writer):
+ """
+ A `Writer` to write copy data to a file-like object.
+
+ :param file: the file where to write copy data. It must be open for writing
+ in binary mode.
+ """
+
+ def __init__(self, file: IO[bytes]):
+ self.file = file
+
+ def write(self, data: Buffer) -> None:
+ self.file.write(data) # type: ignore[arg-type]
+
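+
+# A sketch of using FileWriter to spool COPY data to a file rather than to a
+# connection (the file name and the records are hypothetical). Because no
+# COPY statement runs on the cursor, the format is chosen explicitly:
+#
+#     with open("data.out", "wb") as f:
+#         writer = FileWriter(f)
+#         with Copy(cur, binary=False, writer=writer) as copy:
+#             for record in [(10, "hello"), (40, "world")]:
+#                 copy.write_row(record)
+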
+
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
+ """Manage an asynchronous :sql:`COPY` operation."""
+
+ __module__ = "psycopg"
+
+ writer: "AsyncWriter"
+
+ def __init__(
+ self,
+ cursor: "AsyncCursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["AsyncWriter"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
+
+ if not writer:
+ writer = AsyncLibpqWriter(cursor)
+
+ self.writer = writer
+ self._write = writer.write
+
+ async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
+ self._enter()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.finish(exc_val)
+
+ async def __aiter__(self) -> AsyncIterator[Buffer]:
+ while True:
+ data = await self.read()
+ if not data:
+ break
+ yield data
+
+ async def read(self) -> Buffer:
+ return await self.connection.wait(self._read_gen())
+
+ async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
+ while True:
+ record = await self.read_row()
+ if record is None:
+ break
+ yield record
+
+ async def read_row(self) -> Optional[Tuple[Any, ...]]:
+ return await self.connection.wait(self._read_row_gen())
+
+ async def write(self, buffer: Union[Buffer, str]) -> None:
+ data = self.formatter.write(buffer)
+ if data:
+ await self._write(data)
+
+ async def write_row(self, row: Sequence[Any]) -> None:
+ data = self.formatter.write_row(row)
+ if data:
+ await self._write(data)
+
+ async def finish(self, exc: Optional[BaseException]) -> None:
+ if self._direction == COPY_IN:
+ data = self.formatter.end()
+ if data:
+ await self._write(data)
+ await self.writer.finish(exc)
+ self._finished = True
+ else:
+ await self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class AsyncWriter(ABC):
+ """
+ A class to write copy data somewhere (for async connections).
+ """
+
+ @abstractmethod
+ async def write(self, data: Buffer) -> None:
+ ...
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ pass
+
+
+class AsyncLibpqWriter(AsyncWriter):
+ """
+ An `AsyncWriter` to write copy data to a Postgres database.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
+ self._pgconn = self.connection.pgconn
+
+ async def write(self, data: Buffer) -> None:
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self.connection.wait(copy_to(self._pgconn, data))
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self.connection.wait(
+ copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+ )
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ bmsg: Optional[bytes]
+ if exc:
+ msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+ bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+ else:
+ bmsg = None
+
+ res = await self.connection.wait(copy_end(self._pgconn, bmsg))
+ self.cursor._results = [res]
+
+
+class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
+ """
+ An `AsyncWriter` using a buffer to queue data to write.
+
+ `write()` returns immediately, so that the main thread can be CPU-bound
+ formatting messages, while a worker thread can be IO-bound waiting to write
+ on the connection.
+ """
+
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ super().__init__(cursor)
+
+ self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
+ self._worker: Optional[asyncio.Future[None]] = None
+
+ async def worker(self) -> None:
+ """Push data to the server when available from the copy queue.
+
+ Terminate reading when the queue receives a false-y value.
+
+ The function is designed to be run in a separate task.
+ """
+ while True:
+ data = await self._queue.get()
+ if not data:
+ break
+ await self.connection.wait(copy_to(self._pgconn, data))
+
+ async def write(self, data: Buffer) -> None:
+ if not self._worker:
+ self._worker = create_task(self.worker())
+
+ if len(data) <= MAX_BUFFER_SIZE:
+ # Most used path: we don't need to split the buffer in smaller
+ # bits, so don't make a copy.
+ await self._queue.put(data)
+ else:
+ # Copy a buffer too large in chunks to avoid causing a memory
+ # error in the libpq, which may cause an infinite loop (#255).
+ for i in range(0, len(data), MAX_BUFFER_SIZE):
+ await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
+ await self._queue.put(b"")
+
+ if self._worker:
+ await asyncio.gather(self._worker)
+ self._worker = None # break reference loops if any
+
+ await super().finish(exc)
+
+
+class Formatter(ABC):
+ """
+    A class which understands a copy format (text, binary).
+ """
+
+ format: pq.Format
+
+ def __init__(self, transformer: Transformer):
+ self.transformer = transformer
+ self._write_buffer = bytearray()
+ self._row_mode = False # true if the user is using write_row()
+
+ @abstractmethod
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ ...
+
+ @abstractmethod
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ ...
+
+ @abstractmethod
+ def end(self) -> Buffer:
+ ...
+
+
+class TextFormatter(Formatter):
+
+ format = TEXT
+
+ def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+ super().__init__(transformer)
+ self._encoding = encoding
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if data:
+ return parse_row_text(data, self.transformer)
+ else:
+ return None
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ format_row_text(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ return data.encode(self._encoding)
+ else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If they do, things will fail
+            # downstream.
+ return data
+
+
+class BinaryFormatter(Formatter):
+
+ format = BINARY
+
+ def __init__(self, transformer: Transformer):
+ super().__init__(transformer)
+ self._signature_sent = False
+
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+ if not self._signature_sent:
+ if data[: len(_binary_signature)] != _binary_signature:
+ raise e.DataError(
+ "binary copy doesn't start with the expected signature"
+ )
+ self._signature_sent = True
+ data = data[len(_binary_signature) :]
+
+ elif data == _binary_trailer:
+ return None
+
+ return parse_row_binary(data, self.transformer)
+
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
+ data = self._ensure_bytes(buffer)
+ self._signature_sent = True
+ return data
+
+ def write_row(self, row: Sequence[Any]) -> Buffer:
+ # Note down that we are writing in row mode: it means we will have
+ # to take care of the end-of-copy marker too
+ self._row_mode = True
+
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._signature_sent = True
+
+ format_row_binary(row, self.transformer, self._write_buffer)
+ if len(self._write_buffer) > BUFFER_SIZE:
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+ else:
+ return b""
+
+ def end(self) -> Buffer:
+ # If we have sent no data we need to send the signature
+ # and the trailer
+ if not self._signature_sent:
+ self._write_buffer += _binary_signature
+ self._write_buffer += _binary_trailer
+
+ elif self._row_mode:
+ # if we have sent data already, we have sent the signature
+ # too (either with the first row, or we assume that in
+ # block mode the signature is included).
+ # Write the trailer only if we are sending rows (with the
+            # assumption that whoever is copying binary data is sending the
+ # whole format).
+ self._write_buffer += _binary_trailer
+
+ buffer, self._write_buffer = self._write_buffer, bytearray()
+ return buffer
+
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
+ raise TypeError("cannot copy str data in binary mode: use bytes instead")
+ else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If they do, things will fail
+            # downstream.
+ return data
+
+
+def _format_row_text(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for copy."""
+ if out is None:
+ out = bytearray()
+
+ if not row:
+ out += b"\n"
+ return out
+
+ for item in row:
+ if item is not None:
+ dumper = tx.get_dumper(item, PY_TEXT)
+ b = dumper.dump(item)
+ out += _dump_re.sub(_dump_sub, b)
+ else:
+ out += rb"\N"
+ out += b"\t"
+
+ out[-1:] = b"\n"
+ return out
+
+
+def _format_row_binary(
+ row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+ """Convert a row of objects to the data to send for binary copy."""
+ if out is None:
+ out = bytearray()
+
+ out += _pack_int2(len(row))
+ adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
+ for b in adapted:
+ if b is not None:
+ out += _pack_int4(len(b))
+ out += b
+ else:
+ out += _binary_null
+
+ return out
+
+
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ fields = data.split(b"\t")
+ fields[-1] = fields[-1][:-1] # drop \n
+ row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+ return tx.load_sequence(row)
+
+
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+ row: List[Optional[Buffer]] = []
+ nfields = _unpack_int2(data, 0)[0]
+ pos = 2
+ for i in range(nfields):
+ length = _unpack_int4(data, pos)[0]
+ pos += 4
+ if length >= 0:
+ row.append(data[pos : pos + length])
+ pos += length
+ else:
+ row.append(None)
+
+ return tx.load_sequence(row)
+
+
+_pack_int2 = struct.Struct("!h").pack
+_pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
+
+_binary_signature = (
+ b"PGCOPY\n\xff\r\n\0" # Signature
+ b"\x00\x00\x00\x00" # flags
+ b"\x00\x00\x00\x00" # extra length
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+ b"\b": b"\\b",
+ b"\t": b"\\t",
+ b"\n": b"\\n",
+ b"\v": b"\\v",
+ b"\f": b"\\f",
+ b"\r": b"\\r",
+ b"\\": b"\\\\",
+}
+
+
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
+ return __map[m.group(0)]
+
+
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
+ return __map[m.group(0)]
+
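+# A quick sanity sketch of the escaping round-trip implemented above:
+#
+#     >>> _dump_re.sub(_dump_sub, b"a\tb\nc\\d")
+#     b'a\\tb\\nc\\\\d'
+#     >>> _load_re.sub(_load_sub, b"a\\tb\\nc\\\\d")
+#     b'a\tb\nc\\d'
+#
+# i.e. the bytes special to the COPY text format (tab, newline, backslash and
+# a few other control characters) are dumped as backslash escapes and restored
+# symmetrically on load.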
+
+# Override functions with fast versions if available
+if _psycopg:
+ format_row_text = _psycopg.format_row_text
+ format_row_binary = _psycopg.format_row_binary
+ parse_row_text = _psycopg.parse_row_text
+ parse_row_binary = _psycopg.parse_row_binary
+
+else:
+ format_row_text = _format_row_text
+ format_row_binary = _format_row_binary
+ parse_row_text = _parse_row_text
+ parse_row_binary = _parse_row_binary
diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py
new file mode 100644
index 0000000..323903a
--- /dev/null
+++ b/psycopg/psycopg/crdb/__init__.py
@@ -0,0 +1,19 @@
+"""
+CockroachDB support package.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from . import _types
+from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
+
+adapters = _types.adapters # exposed by the package
+connect = CrdbConnection.connect
+
+_types.register_crdb_adapters(adapters)
+
+__all__ = [
+ "AsyncCrdbConnection",
+ "CrdbConnection",
+ "CrdbConnectionInfo",
+]
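+
+
+# A minimal usage sketch (the connection string is hypothetical):
+#
+#     from psycopg import crdb
+#
+#     with crdb.connect("postgresql://root@localhost:26257/defaultdb") as conn:
+#         print(conn.info.vendor)          # "CockroachDB"
+#         print(conn.info.server_version)  # e.g. 220100 for v22.1.0
+#
+# connect() is CrdbConnection.connect, so the returned connection uses the
+# CockroachDB adapters map registered above.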
diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py
new file mode 100644
index 0000000..5311e05
--- /dev/null
+++ b/psycopg/psycopg/crdb/_types.py
@@ -0,0 +1,163 @@
+"""
+Types configuration specific for CockroachDB.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from enum import Enum
+from .._typeinfo import TypeInfo, TypesRegistry
+
+from ..abc import AdaptContext, NoneType
+from ..postgres import TEXT_OID
+from .._adapters_map import AdaptersMap
+from ..types.enum import EnumDumper, EnumBinaryDumper
+from ..types.none import NoneDumper
+
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+
+class CrdbEnumDumper(EnumDumper):
+ oid = TEXT_OID
+
+
+class CrdbEnumBinaryDumper(EnumBinaryDumper):
+ oid = TEXT_OID
+
+
+class CrdbNoneDumper(NoneDumper):
+ oid = TEXT_OID
+
+
+def register_postgres_adapters(context: AdaptContext) -> None:
+ # Same adapters used by PostgreSQL, or a good starting point for customization
+
+ from ..types import array, bool, composite, datetime
+ from ..types import numeric, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
+
+
+def register_crdb_adapters(context: AdaptContext) -> None:
+ from .. import dbapi20
+ from ..types import array
+
+ register_postgres_adapters(context)
+
+ # String must come after enum to map text oid -> string dumper
+ register_crdb_enum_adapters(context)
+ register_crdb_string_adapters(context)
+ register_crdb_json_adapters(context)
+ register_crdb_net_adapters(context)
+ register_crdb_none_adapters(context)
+
+ dbapi20.register_dbapi20_adapters(adapters)
+
+ array.register_all_arrays(adapters)
+
+
+def register_crdb_string_adapters(context: AdaptContext) -> None:
+ from ..types import string
+
+ # Dump strings with text oid instead of unknown.
+ # Unlike PostgreSQL, CRDB seems able to cast text to most types.
+ context.adapters.register_dumper(str, string.StrDumper)
+ context.adapters.register_dumper(str, string.StrBinaryDumper)
+
+
+def register_crdb_enum_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
+ context.adapters.register_dumper(Enum, CrdbEnumDumper)
+
+
+def register_crdb_json_adapters(context: AdaptContext) -> None:
+ from ..types import json
+
+ adapters = context.adapters
+
+ # CRDB doesn't have json/jsonb: both names map to the jsonb oid
+ adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Json, json.JsonbDumper)
+
+ adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
+ adapters.register_dumper(json.Jsonb, json.JsonbDumper)
+
+ adapters.register_loader("json", json.JsonLoader)
+ adapters.register_loader("jsonb", json.JsonbLoader)
+ adapters.register_loader("json", json.JsonBinaryLoader)
+ adapters.register_loader("jsonb", json.JsonbBinaryLoader)
+
+
+def register_crdb_net_adapters(context: AdaptContext) -> None:
+ from ..types import net
+
+ adapters = context.adapters
+
+ adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
+ adapters.register_dumper(None, net.InetBinaryDumper)
+ adapters.register_loader("inet", net.InetLoader)
+ adapters.register_loader("inet", net.InetBinaryLoader)
+
+
+def register_crdb_none_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, CrdbNoneDumper)
+
+
+for t in [
+ TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
+ TypeInfo('"char"', 18, 1002), # special case, not generated
+ # autogenerated: start
+ # Generated from CockroachDB 22.1.0
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("unknown", 705, 0),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ # autogenerated: end
+]:
+ types.add(t)
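+
+
+# After registration the registry can be queried by name, for instance
+# (illustrative):
+#
+#     >>> types.get_oid("jsonb")
+#     3802
+#     >>> types.get_oid("json")  # aliased to jsonb above
+#     3802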
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
new file mode 100644
index 0000000..6e79ed1
--- /dev/null
+++ b/psycopg/psycopg/crdb/connection.py
@@ -0,0 +1,186 @@
+"""
+CockroachDB-specific connections.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import re
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
+
+from .. import errors as e
+from ..abc import AdaptContext
+from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
+from ..conninfo import ConnectionInfo
+from ..connection import Connection
+from .._adapters_map import AdaptersMap
+from ..connection_async import AsyncConnection
+from ._types import adapters
+
+if TYPE_CHECKING:
+ from ..pq.abc import PGconn
+ from ..cursor import Cursor
+ from ..cursor_async import AsyncCursor
+
+
+class _CrdbConnectionMixin:
+
+ _adapters: Optional[AdaptersMap]
+ pgconn: "PGconn"
+
+ @classmethod
+ def is_crdb(
+ cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
+ ) -> bool:
+ """
+ Return `!True` if the server connected to `!conn` is CockroachDB.
+ """
+ if isinstance(conn, (Connection, AsyncConnection)):
+ conn = conn.pgconn
+
+ return bool(conn.parameter_status(b"crdb_version"))
+
+ @property
+ def adapters(self) -> AdaptersMap:
+ if not self._adapters:
+ # By default, use CockroachDB adapters map
+ self._adapters = AdaptersMap(adapters)
+
+ return self._adapters
+
+ @property
+ def info(self) -> "CrdbConnectionInfo":
+ return CrdbConnectionInfo(self.pgconn)
+
+ def _check_tpc(self) -> None:
+ if self.is_crdb(self.pgconn):
+            raise e.NotSupportedError("CockroachDB doesn't support two-phase commit")
+
+
+class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
+ """
+ Wrapper for a connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+ """
+ Connect to a database server and return a new `CrdbConnection` instance.
+ """
+ return super().connect(conninfo, **kwargs) # type: ignore[return-value]
+
+
+class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
+ """
+ Wrapper for an async connection to a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ async def connect(
+ cls, conninfo: str = "", **kwargs: Any
+ ) -> "AsyncCrdbConnection[Any]":
+ return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
+
+
+class CrdbConnectionInfo(ConnectionInfo):
+ """
+ `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
+ """
+
+ __module__ = "psycopg.crdb"
+
+ @property
+ def vendor(self) -> str:
+ return "CockroachDB"
+
+ @property
+ def server_version(self) -> int:
+ """
+ Return the CockroachDB server version connected.
+
+ Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
+ """
+ sver = self.parameter_status("crdb_version")
+ if not sver:
+ raise e.InternalError("'crdb_version' parameter status not set")
+
+ ver = self.parse_crdb_version(sver)
+ if ver is None:
+ raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
+
+ return ver
+
+ @classmethod
+    def parse_crdb_version(cls, sver: str) -> Optional[int]:
+ m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+ if not m:
+ return None
+
+ return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
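+
+
+# For instance (the version string below is hypothetical but follows the
+# format reported in the crdb_version parameter):
+#
+#     >>> CrdbConnectionInfo.parse_crdb_version(
+#     ...     "CockroachDB CCL v22.1.0 (x86_64-pc-linux-gnu)")
+#     220100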
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
new file mode 100644
index 0000000..42c3804
--- /dev/null
+++ b/psycopg/psycopg/cursor.py
@@ -0,0 +1,921 @@
+"""
+psycopg cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from functools import partial
+from types import TracebackType
+from typing import Any, Generic, Iterable, Iterator, List
+from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
+from typing import overload, TYPE_CHECKING
+from contextlib import contextmanager
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .copy import Copy, Writer as CopyWriter
+from .rows import Row, RowMaker, RowFactory
+from ._column import Column
+from ._queries import PostgresQuery, PostgresClientQuery
+from ._pipeline import Pipeline
+from ._encodings import pgconn_encoding
+from ._preparing import Prepare
+from .generators import execute, fetch, send
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+ from .pq.abc import PGconn, PGresult
+ from .connection import Connection
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class BaseCursor(Generic[ConnectionType, Row]):
+ __slots__ = """
+ _conn format _adapters arraysize _closed _results pgresult _pos
+ _iresult _rowcount _query _tx _last_query _row_factory _make_row
+ _pgconn _execmany_returning
+ __weakref__
+ """.split()
+
+ ExecStatus = pq.ExecStatus
+
+ _tx: "Transformer"
+ _make_row: RowMaker[Row]
+ _pgconn: "PGconn"
+
+ def __init__(self, connection: ConnectionType):
+ self._conn = connection
+ self.format = TEXT
+ self._pgconn = connection.pgconn
+ self._adapters = adapt.AdaptersMap(connection.adapters)
+ self.arraysize = 1
+ self._closed = False
+ self._last_query: Optional[Query] = None
+ self._reset()
+
+ def _reset(self, reset_query: bool = True) -> None:
+ self._results: List["PGresult"] = []
+ self.pgresult: Optional["PGresult"] = None
+ self._pos = 0
+ self._iresult = 0
+ self._rowcount = -1
+ self._query: Optional[PostgresQuery]
+ # None if executemany() not executing, True/False according to returning state
+ self._execmany_returning: Optional[bool] = None
+ if reset_query:
+ self._query = None
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self._pgconn)
+ if self._closed:
+ status = "closed"
+ elif self.pgresult:
+ status = pq.ExecStatus(self.pgresult.status).name
+ else:
+ status = "no result"
+ return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
+
+ @property
+ def connection(self) -> ConnectionType:
+ """The connection this cursor is using."""
+ return self._conn
+
+ @property
+ def adapters(self) -> adapt.AdaptersMap:
+ return self._adapters
+
+ @property
+ def closed(self) -> bool:
+ """`True` if the cursor is closed."""
+ return self._closed
+
+ @property
+ def description(self) -> Optional[List[Column]]:
+ """
+ A list of `Column` objects describing the current resultset.
+
+ `!None` if the current resultset didn't return tuples.
+ """
+ res = self.pgresult
+
+ # We return columns if we have nfields, but also if we don't but
+ # the query said we got tuples (mostly to handle the super useful
+        # query "SELECT ;")
+ if res and (
+ res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
+ ):
+ return [Column(self, i) for i in range(res.nfields)]
+ else:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+        """Number of records affected by the previous operation."""
+ return self._rowcount
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+ return self._pos if tuples else None
+
+ def setinputsizes(self, sizes: Sequence[Any]) -> None:
+ # no-op
+ pass
+
+ def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
+ # no-op
+ pass
+
+ def nextset(self) -> Optional[bool]:
+ """
+ Move to the result set of the next query executed through `executemany()`
+ or to the next result set if `execute()` returned more than one.
+
+ Return `!True` if a new result is available, which will be the one
+ methods `!fetch*()` will operate on.
+ """
+ if self._iresult < len(self._results) - 1:
+ self._select_current_result(self._iresult + 1)
+ return True
+ else:
+ return None
+
+ @property
+ def statusmessage(self) -> Optional[str]:
+ """
+ The command status tag from the last SQL command executed.
+
+ `!None` if the cursor doesn't have a result available.
+ """
+ msg = self.pgresult.command_status if self.pgresult else None
+ return msg.decode() if msg else None
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ raise NotImplementedError
+
+ #
+ # Generators for the high level operations on the cursor
+ #
+ # Like for sync/async connections, these are implemented as generators
+    # so that different concurrency strategies (threads, asyncio) can use their
+ # own way of waiting (or better, `connection.wait()`).
+ #
+
+ def _execute_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `Cursor.execute()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ results = yield from self._maybe_prepare_gen(
+ pgq, prepare=prepare, binary=binary
+ )
+ if self._conn._pipeline:
+ yield from self._conn._pipeline._communicate_gen()
+ else:
+ assert results is not None
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0)
+
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines available.
+ """
+ pipeline = self._conn._pipeline
+ assert pipeline
+
+ yield from self._start_query(query)
+ self._rowcount = 0
+
+ assert self._execmany_returning is None
+ self._execmany_returning = returning
+
+ first = True
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ yield from self._maybe_prepare_gen(pgq, prepare=True)
+ yield from pipeline._communicate_gen()
+
+ self._last_query = query
+
+ if returning:
+ yield from pipeline._fetch_gen(flush=True)
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _executemany_gen_no_pipeline(
+ self, query: Query, params_seq: Iterable[Params], returning: bool
+ ) -> PQGen[None]:
+ """
+ Generator implementing `Cursor.executemany()` with pipelines not available.
+ """
+ yield from self._start_query(query)
+ first = True
+ nrows = 0
+ for params in params_seq:
+ if first:
+ pgq = self._convert_query(query, params)
+ self._query = pgq
+ first = False
+ else:
+ pgq.dump(params)
+
+ results = yield from self._maybe_prepare_gen(pgq, prepare=True)
+ assert results is not None
+ self._check_results(results)
+ if returning:
+ self._results.extend(results)
+
+ for res in results:
+ nrows += res.command_tuples or 0
+
+ if self._results:
+ self._select_current_result(0)
+
+ # Override rowcount for the first result. Calls to nextset() will change
+ # it to the value of that result only, but we hope nobody will notice.
+ # You haven't read this comment.
+ self._rowcount = nrows
+ self._last_query = query
+
+ for cmd in self._conn._prepared.get_maintenance_commands():
+ yield from self._conn._exec_command(cmd)
+
+ def _maybe_prepare_gen(
+ self,
+ pgq: PostgresQuery,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[Optional[List["PGresult"]]]:
+ # Check if the query is prepared or needs preparing
+ prep, name = self._get_prepared(pgq, prepare)
+ if prep is Prepare.NO:
+ # The query must be executed without preparing
+ self._execute_send(pgq, binary=binary)
+ else:
+ # If the query is not already prepared, prepare it.
+ if prep is Prepare.SHOULD:
+ self._send_prepare(name, pgq)
+ if not self._conn._pipeline:
+ (result,) = yield from execute(self._pgconn)
+ if result.status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ # Then execute it.
+ self._send_query_prepared(name, pgq, binary=binary)
+
+ # Update the prepare state of the query.
+ # If an operation requires to flush our prepared statements cache,
+ # it will be added to the maintenance commands to execute later.
+ key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+ if self._conn._pipeline:
+ queued = None
+ if key is not None:
+ queued = (key, prep, name)
+ self._conn._pipeline.result_queue.append((self, queued))
+ return None
+
+ # run the query
+ results = yield from execute(self._pgconn)
+
+ if key is not None:
+ self._conn._prepared.validate(key, prep, name, results)
+
+ return results
+
+ def _get_prepared(
+ self, pgq: PostgresQuery, prepare: Optional[bool] = None
+ ) -> Tuple[Prepare, bytes]:
+ return self._conn._prepared.get(pgq, prepare)
+
+ def _stream_send_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator to send the query for `Cursor.stream()`."""
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, binary=binary, force_extended=True)
+ self._pgconn.set_single_row_mode()
+ self._last_query = query
+ yield from send(self._pgconn)
+
+ def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
+ res = yield from fetch(self._pgconn)
+ if res is None:
+ return None
+
+ status = res.status
+ if status == SINGLE_TUPLE:
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=first)
+ if first:
+ self._make_row = self._make_row_maker()
+ return res
+
+ elif status == TUPLES_OK or status == COMMAND_OK:
+ # End of single row results
+ while res:
+ res = yield from fetch(self._pgconn)
+ if status != TUPLES_OK:
+ raise e.ProgrammingError(
+ "the operation in stream() didn't produce a result"
+ )
+ return None
+
+ else:
+ # Errors, unexpected values
+ return self._raise_for_result(res)
+
+ def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
+ """Generator to start the processing of a query.
+
+ It is implemented as generator because it may send additional queries,
+ such as `begin`.
+ """
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+
+ self._reset()
+ if not self._last_query or (self._last_query is not query):
+ self._last_query = None
+ self._tx = adapt.Transformer(self)
+ yield from self._conn._start_query()
+
+ def _start_copy_gen(
+ self, statement: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
+ """Generator implementing sending a command for `Cursor.copy()."""
+
+        # The connection gets into an unrecoverable state if we attempt COPY in
+ # pipeline mode. Forbid it explicitly.
+ if self._conn._pipeline:
+ raise e.NotSupportedError("COPY cannot be used in pipeline mode")
+
+ yield from self._start_query()
+
+ # Merge the params client-side
+ if params:
+ pgq = PostgresClientQuery(self._tx)
+ pgq.convert(statement, params)
+ statement = pgq.query
+
+ query = self._convert_query(statement)
+
+ self._execute_send(query, binary=False)
+ results = yield from execute(self._pgconn)
+ if len(results) != 1:
+ raise e.ProgrammingError("COPY cannot be mixed with other operations")
+
+ self._check_copy_result(results[0])
+ self._results = results
+ self._select_current_result(0)
+
+ def _execute_send(
+ self,
+ query: PostgresQuery,
+ *,
+ force_extended: bool = False,
+ binary: Optional[bool] = None,
+ ) -> None:
+ """
+ Implement part of execute() before waiting common to sync and async.
+
+ This is not a generator, but a normal non-blocking function.
+ """
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ self._query = query
+
+ if self._conn._pipeline:
+ # In pipeline mode always use PQsendQueryParams - see #314
+ # Multiple statements in the same query are not allowed anyway.
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_params,
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ )
+ elif force_extended or query.params or fmt == BINARY:
+ self._pgconn.send_query_params(
+ query.query,
+ query.params,
+ param_formats=query.formats,
+ param_types=query.types,
+ result_format=fmt,
+ )
+ else:
+ # If we can, let's use simple query protocol,
+ # as it can execute more than one statement in a single query.
+ self._pgconn.send_query(query.query)
+
+ def _convert_query(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PostgresQuery:
+ pgq = PostgresQuery(self._tx)
+ pgq.convert(query, params)
+ return pgq
+
+ def _check_results(self, results: List["PGresult"]) -> None:
+ """
+ Verify that the results of a query are valid.
+
+ Verify that the query returned at least one result and that they all
+ represent a valid result from the database.
+ """
+ if not results:
+ raise e.InternalError("got no result from the query")
+
+ for res in results:
+ status = res.status
+ if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
+ self._raise_for_result(res)
+
+ def _raise_for_result(self, result: "PGresult") -> NoReturn:
+ """
+ Raise an appropriate error message for an unexpected database result
+ """
+ status = result.status
+ if status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ raise e.ProgrammingError(
+ "COPY cannot be used with this method; use copy() instead"
+ )
+ else:
+ raise e.InternalError(
+ "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
+ )
+
+ def _select_current_result(
+ self, i: int, format: Optional[pq.Format] = None
+ ) -> None:
+ """
+ Select one of the results in the cursor as the active one.
+ """
+ self._iresult = i
+ res = self.pgresult = self._results[i]
+
+ # Note: the only reason to override format is to correctly set
+ # binary loaders on server-side cursors, because send_describe_portal
+ # only returns a text result.
+ self._tx.set_pgresult(res, format=format)
+
+ self._pos = 0
+
+ if res.status == TUPLES_OK:
+ self._rowcount = self.pgresult.ntuples
+
+        # COPY_OUT never has info about nrows. We need such a result for the
+        # columns in order to return a `description`, but not to overwrite the
+ # cursor rowcount (which was set by the Copy object).
+ elif res.status != COPY_OUT:
+ nrows = self.pgresult.command_tuples
+ self._rowcount = nrows if nrows is not None else -1
+
+ self._make_row = self._make_row_maker()
+
+ def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
+ self._check_results(results)
+ first_batch = not self._results
+
+ if self._execmany_returning is None:
+ # Received from execute()
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+
+ else:
+ # Received from executemany()
+ if self._execmany_returning:
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
+ self._rowcount = 0
+
+ # Override rowcount for the first result. Calls to nextset() will
+ # change it to the value of that result only, but we hope nobody
+ # will notice.
+ # You haven't read this comment.
+ if self._rowcount < 0:
+ self._rowcount = 0
+ for res in results:
+ self._rowcount += res.command_tuples or 0
+
+ def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_prepare,
+ name,
+ query.query,
+ param_types=query.types,
+ )
+ )
+ self._conn._pipeline.result_queue.append(None)
+ else:
+ self._pgconn.send_prepare(name, query.query, param_types=query.types)
+
+ def _send_query_prepared(
+ self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
+ ) -> None:
+ if binary is None:
+ fmt = self.format
+ else:
+ fmt = BINARY if binary else TEXT
+
+ if self._conn._pipeline:
+ self._conn._pipeline.command_queue.append(
+ partial(
+ self._pgconn.send_query_prepared,
+ name,
+ pgq.params,
+ param_formats=pgq.formats,
+ result_format=fmt,
+ )
+ )
+ else:
+ self._pgconn.send_query_prepared(
+ name, pgq.params, param_formats=pgq.formats, result_format=fmt
+ )
+
+ def _check_result_for_fetch(self) -> None:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ res = self.pgresult
+ if not res:
+ raise e.ProgrammingError("no result available")
+
+ status = res.status
+ if status == TUPLES_OK:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(res, encoding=self._encoding)
+ elif status == PIPELINE_ABORTED:
+ raise e.PipelineAborted("pipeline aborted")
+ else:
+ raise e.ProgrammingError("the last operation didn't produce a result")
+
+ def _check_copy_result(self, result: "PGresult") -> None:
+ """
+ Check that the value returned in a copy() operation is a legit COPY.
+ """
+ status = result.status
+ if status == COPY_IN or status == COPY_OUT:
+ return
+ elif status == FATAL_ERROR:
+ raise e.error_from_result(result, encoding=self._encoding)
+ else:
+ raise e.ProgrammingError(
+ "copy() should be used only with COPY ... TO STDOUT or COPY ..."
+ f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
+ )
+
+ def _scroll(self, value: int, mode: str) -> None:
+ self._check_result_for_fetch()
+ assert self.pgresult
+ if mode == "relative":
+ newpos = self._pos + value
+ elif mode == "absolute":
+ newpos = value
+ else:
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ if not 0 <= newpos < self.pgresult.ntuples:
+ raise IndexError("position out of bound")
+ self._pos = newpos
+
+ def _close(self) -> None:
+ """Non-blocking part of closing. Common to sync/async."""
+ # Don't reset the query because it may be useful to investigate after
+ # an error.
+ self._reset(reset_query=False)
+ self._closed = True
+
+ @property
+ def _encoding(self) -> str:
+ return pgconn_encoding(self._pgconn)
+
+
+class Cursor(BaseCursor["Connection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="Cursor[Any]")
+
+ @overload
+ def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "Cursor[Row]",
+ connection: "Connection[Any]",
+ *,
+ row_factory: RowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ def __enter__(self: _Self) -> _Self:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ self._close()
+
+ @property
+ def row_factory(self) -> RowFactory[Row]:
+ """Writable attribute to control how result rows are formed."""
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: RowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ """
+ Execute a query or command to the database.
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ """
+ Execute the same command with a sequence of input data.
+ """
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ with self._conn.pipeline(), self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ else:
+ with self._conn.lock:
+ self._conn.wait(
+ self._executemany_gen_no_pipeline(query, params_seq, returning)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
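+    # A usage sketch for the returning flag (the table and values are
+    # hypothetical): each input set produces its own result, which can be
+    # walked with fetchone() and nextset():
+    #
+    #     cur.executemany(
+    #         "INSERT INTO sample (num) VALUES (%s) RETURNING id",
+    #         [(10,), (20,)],
+    #         returning=True,
+    #     )
+    #     while True:
+    #         print(cur.fetchone())
+    #         if not cur.nextset():
+    #             break
+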
+ def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> Iterator[Row]:
+ """
+ Iterate row-by-row on a result from the database.
+ """
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ with self._conn.lock:
+
+ try:
+ self._conn.wait(self._stream_send_gen(query, params, binary=binary))
+ first = True
+ while self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while self._conn.wait(self._stream_fetchone_gen(first=False)):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
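+    # Usage sketch (illustrative; process() and very_big_table are
+    # hypothetical): stream() suits result sets too large to fetch at once,
+    # since rows are retrieved one by one, e.g.
+    #
+    #     for record in cur.stream("select * from very_big_table"):
+    #         process(record)
+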
+ def fetchone(self) -> Optional[Row]:
+ """
+ Return the next record from the current recordset.
+
+        Return `!None` if the recordset is finished.
+
+ :rtype: Optional[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ record = self._tx.load_row(self._pos, self._make_row)
+ if record is not None:
+ self._pos += 1
+ return record
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ """
+ Return the next `!size` records from the current recordset.
+
+        `!size` defaults to `!self.arraysize` if not specified.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ def fetchall(self) -> List[Row]:
+ """
+ Return all the remaining records from the current recordset.
+
+ :rtype: Sequence[Row], with Row defined by `row_factory`
+ """
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ def __iter__(self) -> Iterator[Row]:
+ self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ """
+ Move the cursor in the result set to a new position according to mode.
+
+ If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
+ the current position in the result set; if set to ``'absolute'``,
+ `!value` states an absolute target position.
+
+ Raise `!IndexError` in case a scroll operation would leave the result
+ set. In this case the position will not change.
+ """
+ self._fetch_pipeline()
+ self._scroll(value, mode)
+
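+    # Example (illustrative): cur.scroll(0, "absolute") rewinds to the first
+    # record of the current result set; cur.scroll(-1) steps back one record.
+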
+ @contextmanager
+ def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[CopyWriter] = None,
+ ) -> Iterator[Copy]:
+ """
+ Initiate a :sql:`COPY` operation and return an object to manage it.
+
+ :rtype: Copy
+ """
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._start_copy_gen(statement, params))
+
+ with Copy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ # If a fresher result has been set on the cursor by the Copy object,
+ # read its properties (especially rowcount).
+ self._select_current_result(0)
+
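+    # Usage sketch (illustrative; `t` is a hypothetical table): the returned
+    # Copy object is used as a context manager to feed or read COPY data, e.g.
+    #
+    #     with cur.copy("COPY t (a, b) FROM STDIN") as copy:
+    #         copy.write_row((1, "hello"))
+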
+ def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ with self._conn.lock:
+ self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
new file mode 100644
index 0000000..8971d40
--- /dev/null
+++ b/psycopg/psycopg/cursor_async.py
@@ -0,0 +1,250 @@
+"""
+psycopg async cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from types import TracebackType
+from typing import Any, AsyncIterator, Iterable, List
+from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
+from contextlib import asynccontextmanager
+
+from . import pq
+from . import errors as e
+from .abc import Query, Params
+from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
+from .rows import Row, RowMaker, AsyncRowFactory
+from .cursor import BaseCursor
+from ._pipeline import Pipeline
+
+if TYPE_CHECKING:
+ from .connection_async import AsyncConnection
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncCursor[Any]")
+
+ @overload
+ def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: AsyncRowFactory[Row],
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ ):
+ super().__init__(connection)
+ self._row_factory = row_factory or connection.row_factory
+
+ async def __aenter__(self: _Self) -> _Self:
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ async def close(self) -> None:
+ self._close()
+
+ @property
+ def row_factory(self) -> AsyncRowFactory[Row]:
+ return self._row_factory
+
+ @row_factory.setter
+ def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
+ self._row_factory = row_factory
+ if self.pgresult:
+ self._make_row = row_factory(self)
+
+ def _make_row_maker(self) -> RowMaker[Row]:
+ return self._row_factory(self)
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ prepare: Optional[bool] = None,
+ binary: Optional[bool] = None,
+ ) -> _Self:
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
+ )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = False,
+ ) -> None:
+ try:
+ if Pipeline.is_supported():
+ # If there is already a pipeline, ride it, in order to avoid
+ # sending unnecessary Sync.
+ async with self._conn.lock:
+ p = self._conn._pipeline
+ if p:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+ # Otherwise, make a new one
+ if not p:
+ async with self._conn.pipeline(), self._conn.lock:
+ await self._conn.wait(
+ self._executemany_gen_pipeline(query, params_seq, returning)
+ )
+            else:
+                async with self._conn.lock:
+                    await self._conn.wait(
+                        self._executemany_gen_no_pipeline(query, params_seq, returning)
+                    )
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ async def stream(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ ) -> AsyncIterator[Row]:
+ if self._pgconn.pipeline_status:
+ raise e.ProgrammingError("stream() cannot be used in pipeline mode")
+
+ async with self._conn.lock:
+
+ try:
+ await self._conn.wait(
+ self._stream_send_gen(query, params, binary=binary)
+ )
+ first = True
+ while await self._conn.wait(self._stream_fetchone_gen(first)):
+ # We know that, if we got a result, it has a single row.
+ rec: Row = self._tx.load_row(0, self._make_row) # type: ignore
+ yield rec
+ first = False
+
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ finally:
+ if self._pgconn.transaction_status == ACTIVE:
+ # Try to cancel the query, then consume the results
+ # already received.
+ self._conn.cancel()
+ try:
+ while await self._conn.wait(
+ self._stream_fetchone_gen(first=False)
+ ):
+ pass
+ except Exception:
+ pass
+
+ # Try to get out of ACTIVE state. Just do a single attempt, which
+ # should work to recover from an error or query cancelled.
+ try:
+ await self._conn.wait(self._stream_fetchone_gen(first=False))
+ except Exception:
+ pass
+
+ async def fetchone(self) -> Optional[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ rv = self._tx.load_row(self._pos, self._make_row)
+ if rv is not None:
+ self._pos += 1
+ return rv
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+
+ if not size:
+ size = self.arraysize
+ records = self._tx.load_rows(
+ self._pos,
+ min(self._pos + size, self.pgresult.ntuples),
+ self._make_row,
+ )
+ self._pos += len(records)
+ return records
+
+ async def fetchall(self) -> List[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+ assert self.pgresult
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
+ self._pos = self.pgresult.ntuples
+ return records
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ await self._fetch_pipeline()
+ self._check_result_for_fetch()
+
+ def load(pos: int) -> Optional[Row]:
+ return self._tx.load_row(pos, self._make_row)
+
+ while True:
+ row = load(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ yield row
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ self._scroll(value, mode)
+
+ @asynccontextmanager
+ async def copy(
+ self,
+ statement: Query,
+ params: Optional[Params] = None,
+ *,
+ writer: Optional[AsyncCopyWriter] = None,
+ ) -> AsyncIterator[AsyncCopy]:
+ """
+ :rtype: AsyncCopy
+ """
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._start_copy_gen(statement, params))
+
+ async with AsyncCopy(self, writer=writer) as copy:
+ yield copy
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ self._select_current_result(0)
+
+ async def _fetch_pipeline(self) -> None:
+ if (
+ self._execmany_returning is not False
+ and not self.pgresult
+ and self._conn._pipeline
+ ):
+ async with self._conn.lock:
+ await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py
new file mode 100644
index 0000000..3c3d8b7
--- /dev/null
+++ b/psycopg/psycopg/dbapi20.py
@@ -0,0 +1,112 @@
+"""
+Compatibility objects with DBAPI 2.0
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import time
+import datetime as dt
+from math import floor
+from typing import Any, Sequence, Union
+
+from . import postgres
+from .abc import AdaptContext, Buffer
+from .types.string import BytesDumper, BytesBinaryDumper
+
+
+class DBAPITypeObject:
+ def __init__(self, name: str, type_names: Sequence[str]):
+ self.name = name
+ self.values = tuple(postgres.types[n].oid for n in type_names)
+
+ def __repr__(self) -> str:
+ return f"psycopg.{self.name}"
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other in self.values
+ else:
+ return NotImplemented
+
+ def __ne__(self, other: Any) -> bool:
+ if isinstance(other, int):
+ return other not in self.values
+ else:
+ return NotImplemented
+
+
+BINARY = DBAPITypeObject("BINARY", ("bytea",))
+DATETIME = DBAPITypeObject(
+ "DATETIME", "timestamp timestamptz date time timetz interval".split()
+)
+NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
+ROWID = DBAPITypeObject("ROWID", ("oid",))
+STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
+
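+# Example (illustrative): as required by DBAPI 2.0, these singletons compare
+# equal to the OIDs exposed as type codes in cursor.description, e.g.
+#
+#     cur.execute("select 42, 'hi'::text")
+#     cur.description[0].type_code == NUMBER   # True
+#     cur.description[1].type_code == STRING   # True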
+
+class Binary:
+ def __init__(self, obj: Any):
+ self.obj = obj
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class BinaryBinaryDumper(BytesBinaryDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
+
+class BinaryTextDumper(BytesDumper):
+ def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+ if isinstance(obj, Binary):
+ return super().dump(obj.obj)
+ else:
+ return super().dump(obj)
+
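+# Example (illustrative; `files` is a hypothetical table): Binary() is the
+# DBAPI-mandated wrapper for binary parameters; the dumpers above unwrap it
+# before adaptation, e.g.
+#
+#     cur.execute("insert into files (data) values (%s)", [Binary(b"\x00\x01")])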
+
+def Date(year: int, month: int, day: int) -> dt.date:
+ return dt.date(year, month, day)
+
+
+def DateFromTicks(ticks: float) -> dt.date:
+ return TimestampFromTicks(ticks).date()
+
+
+def Time(hour: int, minute: int, second: int) -> dt.time:
+ return dt.time(hour, minute, second)
+
+
+def TimeFromTicks(ticks: float) -> dt.time:
+ return TimestampFromTicks(ticks).time()
+
+
+def Timestamp(
+ year: int, month: int, day: int, hour: int, minute: int, second: int
+) -> dt.datetime:
+ return dt.datetime(year, month, day, hour, minute, second)
+
+
+def TimestampFromTicks(ticks: float) -> dt.datetime:
+ secs = floor(ticks)
+ frac = ticks - secs
+ t = time.localtime(ticks)
+ tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
+ rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo)
+ return rv
+
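+# Note (illustrative): ticks are seconds from the Unix epoch, and the result is
+# expressed in the local timezone, so TimestampFromTicks(0.5) is the local time
+# corresponding to 1970-01-01 00:00:00.5 UTC, with the local offset as tzinfo.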
+
+def register_dbapi20_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Binary, BinaryTextDumper)
+ adapters.register_dumper(Binary, BinaryBinaryDumper)
+
+ # Make them also the default dumpers when dumping by bytea oid
+ adapters.register_dumper(None, BinaryTextDumper)
+ adapters.register_dumper(None, BinaryBinaryDumper)
diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py
new file mode 100644
index 0000000..e176954
--- /dev/null
+++ b/psycopg/psycopg/errors.py
@@ -0,0 +1,1535 @@
+"""
+psycopg exceptions
+
+DBAPI-defined Exceptions are defined in the following hierarchy::
+
+ Exceptions
+ |__Warning
+ |__Error
+ |__InterfaceError
+ |__DatabaseError
+ |__DataError
+ |__OperationalError
+ |__IntegrityError
+ |__InternalError
+ |__ProgrammingError
+ |__NotSupportedError
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from typing_extensions import TypeAlias
+
+from .pq.abc import PGconn, PGresult
+from .pq._enums import DiagnosticField
+from ._compat import TypeGuard
+
+ErrorInfo: TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]]
+
+_sqlcodes: Dict[str, "Type[Error]"] = {}
+
+
+class Warning(Exception):
+ """
+ Exception raised for important warnings.
+
+ Defined for DBAPI compatibility, but never raised by ``psycopg``.
+ """
+
+ __module__ = "psycopg"
+
+
+class Error(Exception):
+ """
+ Base exception for all the errors psycopg will raise.
+
+ Exception that is the base class of all other error exceptions. You can
+ use this to catch all errors with one single `!except` statement.
+
+ This exception is guaranteed to be picklable.
+ """
+
+ __module__ = "psycopg"
+
+ sqlstate: Optional[str] = None
+
+ def __init__(
+ self,
+ *args: Sequence[Any],
+ info: ErrorInfo = None,
+ encoding: str = "utf-8",
+ pgconn: Optional[PGconn] = None
+ ):
+ super().__init__(*args)
+ self._info = info
+ self._encoding = encoding
+ self._pgconn = pgconn
+
+ # Handle sqlstate codes for which we don't have a class.
+ if not self.sqlstate and info:
+ self.sqlstate = self.diag.sqlstate
+
+ @property
+ def pgconn(self) -> Optional[PGconn]:
+ """The connection object, if the error was raised from a connection attempt.
+
+ :rtype: Optional[psycopg.pq.PGconn]
+ """
+ return self._pgconn if self._pgconn else None
+
+ @property
+ def pgresult(self) -> Optional[PGresult]:
+ """The result object, if the exception was raised after a failed query.
+
+ :rtype: Optional[psycopg.pq.PGresult]
+ """
+ return self._info if _is_pgresult(self._info) else None
+
+ @property
+ def diag(self) -> "Diagnostic":
+ """
+ A `Diagnostic` object to inspect details of the errors from the database.
+ """
+ return Diagnostic(self._info, encoding=self._encoding)
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ # To make the exception picklable
+ res[2]["_info"] = _info_to_dict(self._info)
+ res[2]["_pgconn"] = None
+
+ return res
+
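+# Note (illustrative): thanks to __reduce__ above, an Error keeps its
+# diagnostics when pickled, even though the libpq result cannot travel with it:
+#
+#     import pickle
+#     exc2 = pickle.loads(pickle.dumps(exc))   # exc caught from a failed query
+#     exc2.diag.sqlstate                       # still readable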
+
+class InterfaceError(Error):
+ """
+ An error related to the database interface rather than the database itself.
+ """
+
+ __module__ = "psycopg"
+
+
+class DatabaseError(Error):
+ """
+ Exception raised for errors that are related to the database.
+ """
+
+ __module__ = "psycopg"
+
+ def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None):
+ if code:
+ _sqlcodes[code] = cls
+ cls.sqlstate = code
+ if name:
+ _sqlcodes[name] = cls
+
+
+class DataError(DatabaseError):
+ """
+ An error caused by problems with the processed data.
+
+ Examples may be division by zero, numeric value out of range, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class OperationalError(DatabaseError):
+ """
+ An error related to the database's operation.
+
+ These errors are not necessarily under the control of the programmer, e.g.
+ an unexpected disconnect occurs, the data source name is not found, a
+ transaction could not be processed, a memory allocation error occurred
+ during processing, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class IntegrityError(DatabaseError):
+ """
+ An error caused when the relational integrity of the database is affected.
+
+    An example may be a failed foreign key check.
+ """
+
+ __module__ = "psycopg"
+
+
+class InternalError(DatabaseError):
+ """
+    An error generated when the database encounters an internal error.
+
+ Examples could be the cursor is not valid anymore, the transaction is out
+ of sync, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class ProgrammingError(DatabaseError):
+ """
+    Exception raised for programming errors.
+
+ Examples may be table not found or already exists, syntax error in the SQL
+ statement, wrong number of parameters specified, etc.
+ """
+
+ __module__ = "psycopg"
+
+
+class NotSupportedError(DatabaseError):
+ """
+ A method or database API was used which is not supported by the database.
+ """
+
+ __module__ = "psycopg"
+
+
+class ConnectionTimeout(OperationalError):
+ """
+ Exception raised on timeout of the `~psycopg.Connection.connect()` method.
+
+    The error is raised if ``connect_timeout`` is specified and a
+    connection is not obtained within that time.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class PipelineAborted(OperationalError):
+ """
+    Raised when an operation fails because the current pipeline is in aborted state.
+
+ Subclass of `~psycopg.OperationalError`.
+ """
+
+
+class Diagnostic:
+ """Details from a database error report."""
+
+ def __init__(self, info: ErrorInfo, encoding: str = "utf-8"):
+ self._info = info
+ self._encoding = encoding
+
+ @property
+ def severity(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY)
+
+ @property
+ def severity_nonlocalized(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED)
+
+ @property
+ def sqlstate(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SQLSTATE)
+
+ @property
+ def message_primary(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_PRIMARY)
+
+ @property
+ def message_detail(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_DETAIL)
+
+ @property
+ def message_hint(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.MESSAGE_HINT)
+
+ @property
+ def statement_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.STATEMENT_POSITION)
+
+ @property
+ def internal_position(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_POSITION)
+
+ @property
+ def internal_query(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.INTERNAL_QUERY)
+
+ @property
+ def context(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONTEXT)
+
+ @property
+ def schema_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SCHEMA_NAME)
+
+ @property
+ def table_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.TABLE_NAME)
+
+ @property
+ def column_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.COLUMN_NAME)
+
+ @property
+ def datatype_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.DATATYPE_NAME)
+
+ @property
+ def constraint_name(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.CONSTRAINT_NAME)
+
+ @property
+ def source_file(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FILE)
+
+ @property
+ def source_line(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_LINE)
+
+ @property
+ def source_function(self) -> Optional[str]:
+ return self._error_message(DiagnosticField.SOURCE_FUNCTION)
+
+ def _error_message(self, field: DiagnosticField) -> Optional[str]:
+ if self._info:
+ if isinstance(self._info, dict):
+ val = self._info.get(field)
+ else:
+ val = self._info.error_field(field)
+
+ if val is not None:
+ return val.decode(self._encoding, "replace")
+
+ return None
+
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
+ res = super().__reduce__()
+ if isinstance(res, tuple) and len(res) >= 3:
+ res[2]["_info"] = _info_to_dict(self._info)
+
+ return res
+
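+# Example (illustrative; `t` is a hypothetical table): diagnostics are usually
+# reached through the `diag` attribute of a raised error, e.g.
+#
+#     try:
+#         cur.execute("insert into t (id) values (1)")
+#     except psycopg.errors.UniqueViolation as exc:
+#         print(exc.diag.constraint_name, exc.diag.table_name)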
+
+def _info_to_dict(info: ErrorInfo) -> ErrorInfo:
+ """
+ Convert a PGresult to a dictionary to make the info picklable.
+ """
+ # PGresult is a protocol, can't use isinstance
+ if _is_pgresult(info):
+ return {v: info.error_field(v) for v in DiagnosticField}
+ else:
+ return info
+
+
+def lookup(sqlstate: str) -> Type[Error]:
+ """Lookup an error code or `constant name`__ and return its exception class.
+
+ Raise `!KeyError` if the code is not found.
+
+ .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html
+ #ERRCODES-TABLE
+ """
+ return _sqlcodes[sqlstate.upper()]
+
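+# Example (illustrative): both the SQLSTATE code and the constant name are
+# registered, so the following lookups return the same class:
+#
+#     lookup("42P01") is lookup("undefined_table") is UndefinedTable   # True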
+
+def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error:
+ from psycopg import pq
+
+ state = result.error_field(DiagnosticField.SQLSTATE) or b""
+ cls = _class_for_state(state.decode("ascii"))
+ return cls(
+ pq.error_message(result, encoding=encoding),
+ info=result,
+ encoding=encoding,
+ )
+
+
+def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]:
+ """Return True if an ErrorInfo is a PGresult instance."""
+ # PGresult is a protocol, can't use isinstance
+ return hasattr(info, "error_field")
+
+
+def _class_for_state(sqlstate: str) -> Type[Error]:
+ try:
+ return lookup(sqlstate)
+ except KeyError:
+ return get_base_exception(sqlstate)
+
+
+def get_base_exception(sqlstate: str) -> Type[Error]:
+ return (
+ _base_exc_map.get(sqlstate[:2])
+ or _base_exc_map.get(sqlstate[:1])
+ or DatabaseError
+ )
+
+
+_base_exc_map = {
+ "08": OperationalError, # Connection Exception
+ "0A": NotSupportedError, # Feature Not Supported
+ "20": ProgrammingError, # Case Not Foud
+ "21": ProgrammingError, # Cardinality Violation
+ "22": DataError, # Data Exception
+ "23": IntegrityError, # Integrity Constraint Violation
+ "24": InternalError, # Invalid Cursor State
+ "25": InternalError, # Invalid Transaction State
+ "26": ProgrammingError, # Invalid SQL Statement Name *
+ "27": OperationalError, # Triggered Data Change Violation
+ "28": OperationalError, # Invalid Authorization Specification
+ "2B": InternalError, # Dependent Privilege Descriptors Still Exist
+ "2D": InternalError, # Invalid Transaction Termination
+ "2F": OperationalError, # SQL Routine Exception *
+ "34": ProgrammingError, # Invalid Cursor Name *
+ "38": OperationalError, # External Routine Exception *
+ "39": OperationalError, # External Routine Invocation Exception *
+ "3B": OperationalError, # Savepoint Exception *
+ "3D": ProgrammingError, # Invalid Catalog Name
+ "3F": ProgrammingError, # Invalid Schema Name
+ "40": OperationalError, # Transaction Rollback
+ "42": ProgrammingError, # Syntax Error or Access Rule Violation
+ "44": ProgrammingError, # WITH CHECK OPTION Violation
+ "53": OperationalError, # Insufficient Resources
+ "54": OperationalError, # Program Limit Exceeded
+ "55": OperationalError, # Object Not In Prerequisite State
+ "57": OperationalError, # Operator Intervention
+ "58": OperationalError, # System Error (errors external to PostgreSQL itself)
+ "F": OperationalError, # Configuration File Error
+ "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED)
+ "P": ProgrammingError, # PL/pgSQL Error
+ "X": InternalError, # Internal Error
+}
+
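+# Note (illustrative): sqlstates without a dedicated class fall back to the
+# class mapped on their prefix, e.g. get_base_exception("42XYZ") is
+# ProgrammingError, while a completely unknown code falls back to DatabaseError.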
+
+# Error classes generated by tools/update_errors.py
+
+# fmt: off
+# autogenerated: start
+
+
+# Class 02 - No Data (this is also a warning class per the SQL standard)
+
+class NoData(DatabaseError,
+ code='02000', name='NO_DATA'):
+ pass
+
+class NoAdditionalDynamicResultSetsReturned(DatabaseError,
+ code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'):
+ pass
+
+
+# Class 03 - SQL Statement Not Yet Complete
+
+class SqlStatementNotYetComplete(DatabaseError,
+ code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'):
+ pass
+
+
+# Class 08 - Connection Exception
+
+class ConnectionException(OperationalError,
+ code='08000', name='CONNECTION_EXCEPTION'):
+ pass
+
+class SqlclientUnableToEstablishSqlconnection(OperationalError,
+ code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'):
+ pass
+
+class ConnectionDoesNotExist(OperationalError,
+ code='08003', name='CONNECTION_DOES_NOT_EXIST'):
+ pass
+
+class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError,
+ code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'):
+ pass
+
+class ConnectionFailure(OperationalError,
+ code='08006', name='CONNECTION_FAILURE'):
+ pass
+
+class TransactionResolutionUnknown(OperationalError,
+ code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'):
+ pass
+
+class ProtocolViolation(OperationalError,
+ code='08P01', name='PROTOCOL_VIOLATION'):
+ pass
+
+
+# Class 09 - Triggered Action Exception
+
+class TriggeredActionException(DatabaseError,
+ code='09000', name='TRIGGERED_ACTION_EXCEPTION'):
+ pass
+
+
+# Class 0A - Feature Not Supported
+
+class FeatureNotSupported(NotSupportedError,
+ code='0A000', name='FEATURE_NOT_SUPPORTED'):
+ pass
+
+
+# Class 0B - Invalid Transaction Initiation
+
+class InvalidTransactionInitiation(DatabaseError,
+ code='0B000', name='INVALID_TRANSACTION_INITIATION'):
+ pass
+
+
+# Class 0F - Locator Exception
+
+class LocatorException(DatabaseError,
+ code='0F000', name='LOCATOR_EXCEPTION'):
+ pass
+
+class InvalidLocatorSpecification(DatabaseError,
+ code='0F001', name='INVALID_LOCATOR_SPECIFICATION'):
+ pass
+
+
+# Class 0L - Invalid Grantor
+
+class InvalidGrantor(DatabaseError,
+ code='0L000', name='INVALID_GRANTOR'):
+ pass
+
+class InvalidGrantOperation(DatabaseError,
+ code='0LP01', name='INVALID_GRANT_OPERATION'):
+ pass
+
+
+# Class 0P - Invalid Role Specification
+
+class InvalidRoleSpecification(DatabaseError,
+ code='0P000', name='INVALID_ROLE_SPECIFICATION'):
+ pass
+
+
+# Class 0Z - Diagnostics Exception
+
+class DiagnosticsException(DatabaseError,
+ code='0Z000', name='DIAGNOSTICS_EXCEPTION'):
+ pass
+
+class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError,
+ code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'):
+ pass
+
+
+# Class 20 - Case Not Found
+
+class CaseNotFound(ProgrammingError,
+ code='20000', name='CASE_NOT_FOUND'):
+ pass
+
+
+# Class 21 - Cardinality Violation
+
+class CardinalityViolation(ProgrammingError,
+ code='21000', name='CARDINALITY_VIOLATION'):
+ pass
+
+
+# Class 22 - Data Exception
+
+class DataException(DataError,
+ code='22000', name='DATA_EXCEPTION'):
+ pass
+
+class StringDataRightTruncation(DataError,
+ code='22001', name='STRING_DATA_RIGHT_TRUNCATION'):
+ pass
+
+class NullValueNoIndicatorParameter(DataError,
+ code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'):
+ pass
+
+class NumericValueOutOfRange(DataError,
+ code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'):
+ pass
+
+class NullValueNotAllowed(DataError,
+ code='22004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class ErrorInAssignment(DataError,
+ code='22005', name='ERROR_IN_ASSIGNMENT'):
+ pass
+
+class InvalidDatetimeFormat(DataError,
+ code='22007', name='INVALID_DATETIME_FORMAT'):
+ pass
+
+class DatetimeFieldOverflow(DataError,
+ code='22008', name='DATETIME_FIELD_OVERFLOW'):
+ pass
+
+class InvalidTimeZoneDisplacementValue(DataError,
+ code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'):
+ pass
+
+class EscapeCharacterConflict(DataError,
+ code='2200B', name='ESCAPE_CHARACTER_CONFLICT'):
+ pass
+
+class InvalidUseOfEscapeCharacter(DataError,
+ code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidEscapeOctet(DataError,
+ code='2200D', name='INVALID_ESCAPE_OCTET'):
+ pass
+
+class ZeroLengthCharacterString(DataError,
+ code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'):
+ pass
+
+class MostSpecificTypeMismatch(DataError,
+ code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'):
+ pass
+
+class SequenceGeneratorLimitExceeded(DataError,
+ code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'):
+ pass
+
+class NotAnXmlDocument(DataError,
+ code='2200L', name='NOT_AN_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlDocument(DataError,
+ code='2200M', name='INVALID_XML_DOCUMENT'):
+ pass
+
+class InvalidXmlContent(DataError,
+ code='2200N', name='INVALID_XML_CONTENT'):
+ pass
+
+class InvalidXmlComment(DataError,
+ code='2200S', name='INVALID_XML_COMMENT'):
+ pass
+
+class InvalidXmlProcessingInstruction(DataError,
+ code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'):
+ pass
+
+class InvalidIndicatorParameterValue(DataError,
+ code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'):
+ pass
+
+class SubstringError(DataError,
+ code='22011', name='SUBSTRING_ERROR'):
+ pass
+
+class DivisionByZero(DataError,
+ code='22012', name='DIVISION_BY_ZERO'):
+ pass
+
+class InvalidPrecedingOrFollowingSize(DataError,
+ code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'):
+ pass
+
+class InvalidArgumentForNtileFunction(DataError,
+ code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'):
+ pass
+
+class IntervalFieldOverflow(DataError,
+ code='22015', name='INTERVAL_FIELD_OVERFLOW'):
+ pass
+
+class InvalidArgumentForNthValueFunction(DataError,
+ code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'):
+ pass
+
+class InvalidCharacterValueForCast(DataError,
+ code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'):
+ pass
+
+class InvalidEscapeCharacter(DataError,
+ code='22019', name='INVALID_ESCAPE_CHARACTER'):
+ pass
+
+class InvalidRegularExpression(DataError,
+ code='2201B', name='INVALID_REGULAR_EXPRESSION'):
+ pass
+
+class InvalidArgumentForLogarithm(DataError,
+ code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'):
+ pass
+
+class InvalidArgumentForPowerFunction(DataError,
+ code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'):
+ pass
+
+class InvalidArgumentForWidthBucketFunction(DataError,
+ code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'):
+ pass
+
+class InvalidRowCountInLimitClause(DataError,
+ code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'):
+ pass
+
+class InvalidRowCountInResultOffsetClause(DataError,
+ code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'):
+ pass
+
+class CharacterNotInRepertoire(DataError,
+ code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'):
+ pass
+
+class IndicatorOverflow(DataError,
+ code='22022', name='INDICATOR_OVERFLOW'):
+ pass
+
+class InvalidParameterValue(DataError,
+ code='22023', name='INVALID_PARAMETER_VALUE'):
+ pass
+
+class UnterminatedCString(DataError,
+ code='22024', name='UNTERMINATED_C_STRING'):
+ pass
+
+class InvalidEscapeSequence(DataError,
+ code='22025', name='INVALID_ESCAPE_SEQUENCE'):
+ pass
+
+class StringDataLengthMismatch(DataError,
+ code='22026', name='STRING_DATA_LENGTH_MISMATCH'):
+ pass
+
+class TrimError(DataError,
+ code='22027', name='TRIM_ERROR'):
+ pass
+
+class ArraySubscriptError(DataError,
+ code='2202E', name='ARRAY_SUBSCRIPT_ERROR'):
+ pass
+
+class InvalidTablesampleRepeat(DataError,
+ code='2202G', name='INVALID_TABLESAMPLE_REPEAT'):
+ pass
+
+class InvalidTablesampleArgument(DataError,
+ code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'):
+ pass
+
+class DuplicateJsonObjectKeyValue(DataError,
+ code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'):
+ pass
+
+class InvalidArgumentForSqlJsonDatetimeFunction(DataError,
+ code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'):
+ pass
+
+class InvalidJsonText(DataError,
+ code='22032', name='INVALID_JSON_TEXT'):
+ pass
+
+class InvalidSqlJsonSubscript(DataError,
+ code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'):
+ pass
+
+class MoreThanOneSqlJsonItem(DataError,
+ code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'):
+ pass
+
+class NoSqlJsonItem(DataError,
+ code='22035', name='NO_SQL_JSON_ITEM'):
+ pass
+
+class NonNumericSqlJsonItem(DataError,
+ code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'):
+ pass
+
+class NonUniqueKeysInAJsonObject(DataError,
+ code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'):
+ pass
+
+class SingletonSqlJsonItemRequired(DataError,
+ code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'):
+ pass
+
+class SqlJsonArrayNotFound(DataError,
+ code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'):
+ pass
+
+class SqlJsonMemberNotFound(DataError,
+ code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonNumberNotFound(DataError,
+ code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'):
+ pass
+
+class SqlJsonObjectNotFound(DataError,
+ code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'):
+ pass
+
+class TooManyJsonArrayElements(DataError,
+ code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'):
+ pass
+
+class TooManyJsonObjectMembers(DataError,
+ code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'):
+ pass
+
+class SqlJsonScalarRequired(DataError,
+ code='2203F', name='SQL_JSON_SCALAR_REQUIRED'):
+ pass
+
+class SqlJsonItemCannotBeCastToTargetType(DataError,
+ code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'):
+ pass
+
+class FloatingPointException(DataError,
+ code='22P01', name='FLOATING_POINT_EXCEPTION'):
+ pass
+
+class InvalidTextRepresentation(DataError,
+ code='22P02', name='INVALID_TEXT_REPRESENTATION'):
+ pass
+
+class InvalidBinaryRepresentation(DataError,
+ code='22P03', name='INVALID_BINARY_REPRESENTATION'):
+ pass
+
+class BadCopyFileFormat(DataError,
+ code='22P04', name='BAD_COPY_FILE_FORMAT'):
+ pass
+
+class UntranslatableCharacter(DataError,
+ code='22P05', name='UNTRANSLATABLE_CHARACTER'):
+ pass
+
+class NonstandardUseOfEscapeCharacter(DataError,
+ code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'):
+ pass
+
+
+# Class 23 - Integrity Constraint Violation
+
+class IntegrityConstraintViolation(IntegrityError,
+ code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class RestrictViolation(IntegrityError,
+ code='23001', name='RESTRICT_VIOLATION'):
+ pass
+
+class NotNullViolation(IntegrityError,
+ code='23502', name='NOT_NULL_VIOLATION'):
+ pass
+
+class ForeignKeyViolation(IntegrityError,
+ code='23503', name='FOREIGN_KEY_VIOLATION'):
+ pass
+
+class UniqueViolation(IntegrityError,
+ code='23505', name='UNIQUE_VIOLATION'):
+ pass
+
+class CheckViolation(IntegrityError,
+ code='23514', name='CHECK_VIOLATION'):
+ pass
+
+class ExclusionViolation(IntegrityError,
+ code='23P01', name='EXCLUSION_VIOLATION'):
+ pass
+
+
+# Class 24 - Invalid Cursor State
+
+class InvalidCursorState(InternalError,
+ code='24000', name='INVALID_CURSOR_STATE'):
+ pass
+
+
+# Class 25 - Invalid Transaction State
+
+class InvalidTransactionState(InternalError,
+ code='25000', name='INVALID_TRANSACTION_STATE'):
+ pass
+
+class ActiveSqlTransaction(InternalError,
+ code='25001', name='ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class BranchTransactionAlreadyActive(InternalError,
+ code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'):
+ pass
+
+class InappropriateAccessModeForBranchTransaction(InternalError,
+ code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class InappropriateIsolationLevelForBranchTransaction(InternalError,
+ code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class NoActiveSqlTransactionForBranchTransaction(InternalError,
+ code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'):
+ pass
+
+class ReadOnlySqlTransaction(InternalError,
+ code='25006', name='READ_ONLY_SQL_TRANSACTION'):
+ pass
+
+class SchemaAndDataStatementMixingNotSupported(InternalError,
+ code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'):
+ pass
+
+class HeldCursorRequiresSameIsolationLevel(InternalError,
+ code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'):
+ pass
+
+class NoActiveSqlTransaction(InternalError,
+ code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'):
+ pass
+
+class InFailedSqlTransaction(InternalError,
+ code='25P02', name='IN_FAILED_SQL_TRANSACTION'):
+ pass
+
+class IdleInTransactionSessionTimeout(InternalError,
+ code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 26 - Invalid SQL Statement Name
+
+class InvalidSqlStatementName(ProgrammingError,
+ code='26000', name='INVALID_SQL_STATEMENT_NAME'):
+ pass
+
+
+# Class 27 - Triggered Data Change Violation
+
+class TriggeredDataChangeViolation(OperationalError,
+ code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'):
+ pass
+
+
+# Class 28 - Invalid Authorization Specification
+
+class InvalidAuthorizationSpecification(OperationalError,
+ code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'):
+ pass
+
+class InvalidPassword(OperationalError,
+ code='28P01', name='INVALID_PASSWORD'):
+ pass
+
+
+# Class 2B - Dependent Privilege Descriptors Still Exist
+
+class DependentPrivilegeDescriptorsStillExist(InternalError,
+ code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'):
+ pass
+
+class DependentObjectsStillExist(InternalError,
+ code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'):
+ pass
+
+
+# Class 2D - Invalid Transaction Termination
+
+class InvalidTransactionTermination(InternalError,
+ code='2D000', name='INVALID_TRANSACTION_TERMINATION'):
+ pass
+
+
+# Class 2F - SQL Routine Exception
+
+class SqlRoutineException(OperationalError,
+ code='2F000', name='SQL_ROUTINE_EXCEPTION'):
+ pass
+
+class ModifyingSqlDataNotPermitted(OperationalError,
+ code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttempted(OperationalError,
+ code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermitted(OperationalError,
+ code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class FunctionExecutedNoReturnStatement(OperationalError,
+ code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'):
+ pass
+
+
+# Class 34 - Invalid Cursor Name
+
+class InvalidCursorName(ProgrammingError,
+ code='34000', name='INVALID_CURSOR_NAME'):
+ pass
+
+
+# Class 38 - External Routine Exception
+
+class ExternalRoutineException(OperationalError,
+ code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'):
+ pass
+
+class ContainingSqlNotPermitted(OperationalError,
+ code='38001', name='CONTAINING_SQL_NOT_PERMITTED'):
+ pass
+
+class ModifyingSqlDataNotPermittedExt(OperationalError,
+ code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+class ProhibitedSqlStatementAttemptedExt(OperationalError,
+ code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'):
+ pass
+
+class ReadingSqlDataNotPermittedExt(OperationalError,
+ code='38004', name='READING_SQL_DATA_NOT_PERMITTED'):
+ pass
+
+
+# Class 39 - External Routine Invocation Exception
+
+class ExternalRoutineInvocationException(OperationalError,
+ code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'):
+ pass
+
+class InvalidSqlstateReturned(OperationalError,
+ code='39001', name='INVALID_SQLSTATE_RETURNED'):
+ pass
+
+class NullValueNotAllowedExt(OperationalError,
+ code='39004', name='NULL_VALUE_NOT_ALLOWED'):
+ pass
+
+class TriggerProtocolViolated(OperationalError,
+ code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+class SrfProtocolViolated(OperationalError,
+ code='39P02', name='SRF_PROTOCOL_VIOLATED'):
+ pass
+
+class EventTriggerProtocolViolated(OperationalError,
+ code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'):
+ pass
+
+
+# Class 3B - Savepoint Exception
+
+class SavepointException(OperationalError,
+ code='3B000', name='SAVEPOINT_EXCEPTION'):
+ pass
+
+class InvalidSavepointSpecification(OperationalError,
+ code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'):
+ pass
+
+
+# Class 3D - Invalid Catalog Name
+
+class InvalidCatalogName(ProgrammingError,
+ code='3D000', name='INVALID_CATALOG_NAME'):
+ pass
+
+
+# Class 3F - Invalid Schema Name
+
+class InvalidSchemaName(ProgrammingError,
+ code='3F000', name='INVALID_SCHEMA_NAME'):
+ pass
+
+
+# Class 40 - Transaction Rollback
+
+class TransactionRollback(OperationalError,
+ code='40000', name='TRANSACTION_ROLLBACK'):
+ pass
+
+class SerializationFailure(OperationalError,
+ code='40001', name='SERIALIZATION_FAILURE'):
+ pass
+
+class TransactionIntegrityConstraintViolation(OperationalError,
+ code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'):
+ pass
+
+class StatementCompletionUnknown(OperationalError,
+ code='40003', name='STATEMENT_COMPLETION_UNKNOWN'):
+ pass
+
+class DeadlockDetected(OperationalError,
+ code='40P01', name='DEADLOCK_DETECTED'):
+ pass
+
+
+# Class 42 - Syntax Error or Access Rule Violation
+
+class SyntaxErrorOrAccessRuleViolation(ProgrammingError,
+ code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'):
+ pass
+
+class InsufficientPrivilege(ProgrammingError,
+ code='42501', name='INSUFFICIENT_PRIVILEGE'):
+ pass
+
+class SyntaxError(ProgrammingError,
+ code='42601', name='SYNTAX_ERROR'):
+ pass
+
+class InvalidName(ProgrammingError,
+ code='42602', name='INVALID_NAME'):
+ pass
+
+class InvalidColumnDefinition(ProgrammingError,
+ code='42611', name='INVALID_COLUMN_DEFINITION'):
+ pass
+
+class NameTooLong(ProgrammingError,
+ code='42622', name='NAME_TOO_LONG'):
+ pass
+
+class DuplicateColumn(ProgrammingError,
+ code='42701', name='DUPLICATE_COLUMN'):
+ pass
+
+class AmbiguousColumn(ProgrammingError,
+ code='42702', name='AMBIGUOUS_COLUMN'):
+ pass
+
+class UndefinedColumn(ProgrammingError,
+ code='42703', name='UNDEFINED_COLUMN'):
+ pass
+
+class UndefinedObject(ProgrammingError,
+ code='42704', name='UNDEFINED_OBJECT'):
+ pass
+
+class DuplicateObject(ProgrammingError,
+ code='42710', name='DUPLICATE_OBJECT'):
+ pass
+
+class DuplicateAlias(ProgrammingError,
+ code='42712', name='DUPLICATE_ALIAS'):
+ pass
+
+class DuplicateFunction(ProgrammingError,
+ code='42723', name='DUPLICATE_FUNCTION'):
+ pass
+
+class AmbiguousFunction(ProgrammingError,
+ code='42725', name='AMBIGUOUS_FUNCTION'):
+ pass
+
+class GroupingError(ProgrammingError,
+ code='42803', name='GROUPING_ERROR'):
+ pass
+
+class DatatypeMismatch(ProgrammingError,
+ code='42804', name='DATATYPE_MISMATCH'):
+ pass
+
+class WrongObjectType(ProgrammingError,
+ code='42809', name='WRONG_OBJECT_TYPE'):
+ pass
+
+class InvalidForeignKey(ProgrammingError,
+ code='42830', name='INVALID_FOREIGN_KEY'):
+ pass
+
+class CannotCoerce(ProgrammingError,
+ code='42846', name='CANNOT_COERCE'):
+ pass
+
+class UndefinedFunction(ProgrammingError,
+ code='42883', name='UNDEFINED_FUNCTION'):
+ pass
+
+class GeneratedAlways(ProgrammingError,
+ code='428C9', name='GENERATED_ALWAYS'):
+ pass
+
+class ReservedName(ProgrammingError,
+ code='42939', name='RESERVED_NAME'):
+ pass
+
+class UndefinedTable(ProgrammingError,
+ code='42P01', name='UNDEFINED_TABLE'):
+ pass
+
+class UndefinedParameter(ProgrammingError,
+ code='42P02', name='UNDEFINED_PARAMETER'):
+ pass
+
+class DuplicateCursor(ProgrammingError,
+ code='42P03', name='DUPLICATE_CURSOR'):
+ pass
+
+class DuplicateDatabase(ProgrammingError,
+ code='42P04', name='DUPLICATE_DATABASE'):
+ pass
+
+class DuplicatePreparedStatement(ProgrammingError,
+ code='42P05', name='DUPLICATE_PREPARED_STATEMENT'):
+ pass
+
+class DuplicateSchema(ProgrammingError,
+ code='42P06', name='DUPLICATE_SCHEMA'):
+ pass
+
+class DuplicateTable(ProgrammingError,
+ code='42P07', name='DUPLICATE_TABLE'):
+ pass
+
+class AmbiguousParameter(ProgrammingError,
+ code='42P08', name='AMBIGUOUS_PARAMETER'):
+ pass
+
+class AmbiguousAlias(ProgrammingError,
+ code='42P09', name='AMBIGUOUS_ALIAS'):
+ pass
+
+class InvalidColumnReference(ProgrammingError,
+ code='42P10', name='INVALID_COLUMN_REFERENCE'):
+ pass
+
+class InvalidCursorDefinition(ProgrammingError,
+ code='42P11', name='INVALID_CURSOR_DEFINITION'):
+ pass
+
+class InvalidDatabaseDefinition(ProgrammingError,
+ code='42P12', name='INVALID_DATABASE_DEFINITION'):
+ pass
+
+class InvalidFunctionDefinition(ProgrammingError,
+ code='42P13', name='INVALID_FUNCTION_DEFINITION'):
+ pass
+
+class InvalidPreparedStatementDefinition(ProgrammingError,
+ code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'):
+ pass
+
+class InvalidSchemaDefinition(ProgrammingError,
+ code='42P15', name='INVALID_SCHEMA_DEFINITION'):
+ pass
+
+class InvalidTableDefinition(ProgrammingError,
+ code='42P16', name='INVALID_TABLE_DEFINITION'):
+ pass
+
+class InvalidObjectDefinition(ProgrammingError,
+ code='42P17', name='INVALID_OBJECT_DEFINITION'):
+ pass
+
+class IndeterminateDatatype(ProgrammingError,
+ code='42P18', name='INDETERMINATE_DATATYPE'):
+ pass
+
+class InvalidRecursion(ProgrammingError,
+ code='42P19', name='INVALID_RECURSION'):
+ pass
+
+class WindowingError(ProgrammingError,
+ code='42P20', name='WINDOWING_ERROR'):
+ pass
+
+class CollationMismatch(ProgrammingError,
+ code='42P21', name='COLLATION_MISMATCH'):
+ pass
+
+class IndeterminateCollation(ProgrammingError,
+ code='42P22', name='INDETERMINATE_COLLATION'):
+ pass
+
+
+# Class 44 - WITH CHECK OPTION Violation
+
+class WithCheckOptionViolation(ProgrammingError,
+ code='44000', name='WITH_CHECK_OPTION_VIOLATION'):
+ pass
+
+
+# Class 53 - Insufficient Resources
+
+class InsufficientResources(OperationalError,
+ code='53000', name='INSUFFICIENT_RESOURCES'):
+ pass
+
+class DiskFull(OperationalError,
+ code='53100', name='DISK_FULL'):
+ pass
+
+class OutOfMemory(OperationalError,
+ code='53200', name='OUT_OF_MEMORY'):
+ pass
+
+class TooManyConnections(OperationalError,
+ code='53300', name='TOO_MANY_CONNECTIONS'):
+ pass
+
+class ConfigurationLimitExceeded(OperationalError,
+ code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'):
+ pass
+
+
+# Class 54 - Program Limit Exceeded
+
+class ProgramLimitExceeded(OperationalError,
+ code='54000', name='PROGRAM_LIMIT_EXCEEDED'):
+ pass
+
+class StatementTooComplex(OperationalError,
+ code='54001', name='STATEMENT_TOO_COMPLEX'):
+ pass
+
+class TooManyColumns(OperationalError,
+ code='54011', name='TOO_MANY_COLUMNS'):
+ pass
+
+class TooManyArguments(OperationalError,
+ code='54023', name='TOO_MANY_ARGUMENTS'):
+ pass
+
+
+# Class 55 - Object Not In Prerequisite State
+
+class ObjectNotInPrerequisiteState(OperationalError,
+ code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'):
+ pass
+
+class ObjectInUse(OperationalError,
+ code='55006', name='OBJECT_IN_USE'):
+ pass
+
+class CantChangeRuntimeParam(OperationalError,
+ code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'):
+ pass
+
+class LockNotAvailable(OperationalError,
+ code='55P03', name='LOCK_NOT_AVAILABLE'):
+ pass
+
+class UnsafeNewEnumValueUsage(OperationalError,
+ code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'):
+ pass
+
+
+# Class 57 - Operator Intervention
+
+class OperatorIntervention(OperationalError,
+ code='57000', name='OPERATOR_INTERVENTION'):
+ pass
+
+class QueryCanceled(OperationalError,
+ code='57014', name='QUERY_CANCELED'):
+ pass
+
+class AdminShutdown(OperationalError,
+ code='57P01', name='ADMIN_SHUTDOWN'):
+ pass
+
+class CrashShutdown(OperationalError,
+ code='57P02', name='CRASH_SHUTDOWN'):
+ pass
+
+class CannotConnectNow(OperationalError,
+ code='57P03', name='CANNOT_CONNECT_NOW'):
+ pass
+
+class DatabaseDropped(OperationalError,
+ code='57P04', name='DATABASE_DROPPED'):
+ pass
+
+class IdleSessionTimeout(OperationalError,
+ code='57P05', name='IDLE_SESSION_TIMEOUT'):
+ pass
+
+
+# Class 58 - System Error (errors external to PostgreSQL itself)
+
+class SystemError(OperationalError,
+ code='58000', name='SYSTEM_ERROR'):
+ pass
+
+class IoError(OperationalError,
+ code='58030', name='IO_ERROR'):
+ pass
+
+class UndefinedFile(OperationalError,
+ code='58P01', name='UNDEFINED_FILE'):
+ pass
+
+class DuplicateFile(OperationalError,
+ code='58P02', name='DUPLICATE_FILE'):
+ pass
+
+
+# Class 72 - Snapshot Failure
+
+class SnapshotTooOld(DatabaseError,
+ code='72000', name='SNAPSHOT_TOO_OLD'):
+ pass
+
+
+# Class F0 - Configuration File Error
+
+class ConfigFileError(OperationalError,
+ code='F0000', name='CONFIG_FILE_ERROR'):
+ pass
+
+class LockFileExists(OperationalError,
+ code='F0001', name='LOCK_FILE_EXISTS'):
+ pass
+
+
+# Class HV - Foreign Data Wrapper Error (SQL/MED)
+
+class FdwError(OperationalError,
+ code='HV000', name='FDW_ERROR'):
+ pass
+
+class FdwOutOfMemory(OperationalError,
+ code='HV001', name='FDW_OUT_OF_MEMORY'):
+ pass
+
+class FdwDynamicParameterValueNeeded(OperationalError,
+ code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'):
+ pass
+
+class FdwInvalidDataType(OperationalError,
+ code='HV004', name='FDW_INVALID_DATA_TYPE'):
+ pass
+
+class FdwColumnNameNotFound(OperationalError,
+ code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'):
+ pass
+
+class FdwInvalidDataTypeDescriptors(OperationalError,
+ code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'):
+ pass
+
+class FdwInvalidColumnName(OperationalError,
+ code='HV007', name='FDW_INVALID_COLUMN_NAME'):
+ pass
+
+class FdwInvalidColumnNumber(OperationalError,
+ code='HV008', name='FDW_INVALID_COLUMN_NUMBER'):
+ pass
+
+class FdwInvalidUseOfNullPointer(OperationalError,
+ code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'):
+ pass
+
+class FdwInvalidStringFormat(OperationalError,
+ code='HV00A', name='FDW_INVALID_STRING_FORMAT'):
+ pass
+
+class FdwInvalidHandle(OperationalError,
+ code='HV00B', name='FDW_INVALID_HANDLE'):
+ pass
+
+class FdwInvalidOptionIndex(OperationalError,
+ code='HV00C', name='FDW_INVALID_OPTION_INDEX'):
+ pass
+
+class FdwInvalidOptionName(OperationalError,
+ code='HV00D', name='FDW_INVALID_OPTION_NAME'):
+ pass
+
+class FdwOptionNameNotFound(OperationalError,
+ code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'):
+ pass
+
+class FdwReplyHandle(OperationalError,
+ code='HV00K', name='FDW_REPLY_HANDLE'):
+ pass
+
+class FdwUnableToCreateExecution(OperationalError,
+ code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'):
+ pass
+
+class FdwUnableToCreateReply(OperationalError,
+ code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'):
+ pass
+
+class FdwUnableToEstablishConnection(OperationalError,
+ code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'):
+ pass
+
+class FdwNoSchemas(OperationalError,
+ code='HV00P', name='FDW_NO_SCHEMAS'):
+ pass
+
+class FdwSchemaNotFound(OperationalError,
+ code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'):
+ pass
+
+class FdwTableNotFound(OperationalError,
+ code='HV00R', name='FDW_TABLE_NOT_FOUND'):
+ pass
+
+class FdwFunctionSequenceError(OperationalError,
+ code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'):
+ pass
+
+class FdwTooManyHandles(OperationalError,
+ code='HV014', name='FDW_TOO_MANY_HANDLES'):
+ pass
+
+class FdwInconsistentDescriptorInformation(OperationalError,
+ code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'):
+ pass
+
+class FdwInvalidAttributeValue(OperationalError,
+ code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'):
+ pass
+
+class FdwInvalidStringLengthOrBufferLength(OperationalError,
+ code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'):
+ pass
+
+class FdwInvalidDescriptorFieldIdentifier(OperationalError,
+ code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'):
+ pass
+
+
+# Class P0 - PL/pgSQL Error
+
+class PlpgsqlError(ProgrammingError,
+ code='P0000', name='PLPGSQL_ERROR'):
+ pass
+
+class RaiseException(ProgrammingError,
+ code='P0001', name='RAISE_EXCEPTION'):
+ pass
+
+class NoDataFound(ProgrammingError,
+ code='P0002', name='NO_DATA_FOUND'):
+ pass
+
+class TooManyRows(ProgrammingError,
+ code='P0003', name='TOO_MANY_ROWS'):
+ pass
+
+class AssertFailure(ProgrammingError,
+ code='P0004', name='ASSERT_FAILURE'):
+ pass
+
+
+# Class XX - Internal Error
+
+class InternalError_(InternalError,
+ code='XX000', name='INTERNAL_ERROR'):
+ pass
+
+class DataCorrupted(InternalError,
+ code='XX001', name='DATA_CORRUPTED'):
+ pass
+
+class IndexCorrupted(InternalError,
+ code='XX002', name='INDEX_CORRUPTED'):
+ pass
+
+
+# autogenerated: end
+# fmt: on
diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py
new file mode 100644
index 0000000..584fe47
--- /dev/null
+++ b/psycopg/psycopg/generators.py
@@ -0,0 +1,320 @@
+"""
+Generators implementing communication protocols with the libpq
+
+Certain operations (connection, querying) are an interleaving of libpq calls
+and waiting for the socket to be ready. This module contains the code to
+execute the operations, yielding a polling state whenever there is something
+to wait for. The functions in the `waiting` module are the ones that wait,
+more or less cooperatively, for the socket to be ready and make these
+generators continue.
+
+All these generators yield pairs (fileno, `Wait`) whenever an operation would
+block. The generator can be restarted sending the appropriate `Ready` state
+when the file descriptor is ready.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import List, Optional, Union
+
+from . import pq
+from . import errors as e
+from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
+from .pq.abc import PGconn, PGresult
+from .waiting import Wait, Ready
+from ._compat import Deque
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding, conninfo_encoding
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+POLL_OK = pq.PollingStatus.OK
+POLL_READING = pq.PollingStatus.READING
+POLL_WRITING = pq.PollingStatus.WRITING
+POLL_FAILED = pq.PollingStatus.FAILED
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+logger = logging.getLogger(__name__)
+
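+# Consumption sketch (illustrative, not part of this module): a blocking
+# waiting function could drive the connection generator below roughly like
+# this, using a plain select() loop (it only handles the R/W states that
+# _connect() yields):
+#
+#     import select
+#
+#     def wait_conn_blocking(gen):
+#         try:
+#             fileno, s = next(gen)
+#             while True:
+#                 rl = [fileno] if s == WAIT_R else []
+#                 wl = [fileno] if s == WAIT_W else []
+#                 select.select(rl, wl, [])
+#                 fileno, s = gen.send(READY_R if rl else READY_W)
+#         except StopIteration as ex:
+#             return ex.value   # the connected PGconn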
+
+def _connect(conninfo: str) -> PQGenConn[PGconn]:
+ """
+ Generator to create a database connection without blocking.
+
+ """
+ conn = pq.PGconn.connect_start(conninfo.encode())
+ while True:
+ if conn.status == BAD:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection is bad: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+
+ status = conn.connect_poll()
+ if status == POLL_OK:
+ break
+ elif status == POLL_READING:
+ yield conn.socket, WAIT_R
+ elif status == POLL_WRITING:
+ yield conn.socket, WAIT_W
+ elif status == POLL_FAILED:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection failed: {pq.error_message(conn, encoding=encoding)}",
+ pgconn=conn,
+ )
+ else:
+ raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn)
+
+ conn.nonblocking = 1
+ return conn
+
+
+def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator sending a query and returning results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ yield from _send(pgconn)
+ rv = yield from _fetch_many(pgconn)
+ return rv
+
+
+def _send(pgconn: PGconn) -> PQGen[None]:
+ """
+ Generator to send a query to the server without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+    similar. Flush the query to the server using nonblocking functions.
+
+ After this generator has finished you may want to cycle using `fetch()`
+ to retrieve the results available.
+ """
+ while True:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ ready = yield WAIT_RW
+ if ready & READY_R:
+ # This call may read notifies: they will be saved in the
+ # PGconn buffer and passed to Python later, in `fetch()`.
+ pgconn.consume_input()
+
+
+def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
+ """
+ Generator retrieving results from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ results: List[PGresult] = []
+ while True:
+ res = yield from _fetch(pgconn)
+ if not res:
+ break
+
+ results.append(res)
+ status = res.status
+ if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ # After entering copy mode the libpq will create a phony result
+ # for every request so let's break the endless loop.
+ break
+
+ if status == PIPELINE_SYNC:
+ # PIPELINE_SYNC is not followed by a NULL, but we return it alone
+ # similarly to other result sets.
+ assert len(results) == 1, results
+ break
+
+ return results
+
+
+def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
+ """
+ Generator retrieving a single result from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return a result from the database (whether success or error).
+ """
+ if pgconn.is_busy():
+ yield WAIT_R
+ while True:
+ pgconn.consume_input()
+ if not pgconn.is_busy():
+ break
+ yield WAIT_R
+
+ _consume_notifies(pgconn)
+
+ return pgconn.get_result()
+
+
+def _pipeline_communicate(
+ pgconn: PGconn, commands: Deque[PipelineCommand]
+) -> PQGen[List[List[PGresult]]]:
+ """Generator to send queries from a connection in pipeline mode while also
+ receiving results.
+
+ Return a list of results, including single PIPELINE_SYNC elements.
+ """
+ results = []
+
+ while True:
+ ready = yield WAIT_RW
+
+ if ready & READY_R:
+ pgconn.consume_input()
+ _consume_notifies(pgconn)
+
+ res: List[PGresult] = []
+ while not pgconn.is_busy():
+ r = pgconn.get_result()
+ if r is None:
+ if not res:
+ break
+ results.append(res)
+ res = []
+ elif r.status == PIPELINE_SYNC:
+ assert not res
+ results.append([r])
+ else:
+ res.append(r)
+
+ if ready & READY_W:
+ pgconn.flush()
+ if not commands:
+ break
+ commands.popleft()()
+
+ return results
+
+
+def _consume_notifies(pgconn: PGconn) -> None:
+ # Consume notifies
+ while True:
+ n = pgconn.notifies()
+ if not n:
+ break
+ if pgconn.notify_handler:
+ pgconn.notify_handler(n)
+
+
+def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
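+ """
+ Generator returning the notifications received after waiting for the
+ connection socket to become readable (the list may be empty).
+ """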
+ yield WAIT_R
+ pgconn.consume_input()
+
+ ns = []
+ while True:
+ n = pgconn.notifies()
+ if n:
+ ns.append(n)
+ else:
+ break
+
+ return ns
+
+
+def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
+ while True:
+ nbytes, data = pgconn.get_copy_data(1)
+ if nbytes != 0:
+ break
+
+ # would block
+ yield WAIT_R
+ pgconn.consume_input()
+
+ if nbytes > 0:
+ # some data
+ return data
+
+ # Retrieve the final result of copy
+ results = yield from _fetch_many(pgconn)
+ if len(results) > 1:
+ # TODO: too brutal? Copy worked.
+ raise e.ProgrammingError("you cannot mix COPY with other operations")
+ result = results[0]
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
+ # Retry enqueuing data until successful.
+ #
+ # WARNING! This can cause an infinite loop if the buffer is too large (see
+ # ticket #255). We avoid it in the Copy object by splitting a large buffer
+ # into smaller ones. We prefer to do it there instead of here in order to
+ # do it upstream of the queue, decoupling the writer task from the producer one.
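+ #
+ # A caller-side sketch of that splitting (``MAX_BUFFER_SIZE`` is a
+ # hypothetical threshold, not a name defined in this module):
+ #
+ #     for i in range(0, len(buffer), MAX_BUFFER_SIZE):
+ #         yield from copy_to(pgconn, buffer[i : i + MAX_BUFFER_SIZE])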
+ while pgconn.put_copy_data(buffer) == 0:
+ yield WAIT_W
+
+
+def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
+ # Retry enqueuing end copy message until successful
+ while pgconn.put_copy_end(error) == 0:
+ yield WAIT_W
+
+ # Repeat until the message is flushed to the server
+ while True:
+ yield WAIT_W
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ # Retrieve the final result of copy
+ (result,) = yield from _fetch_many(pgconn)
+ if result.status != COMMAND_OK:
+ encoding = pgconn_encoding(pgconn)
+ raise e.error_from_result(result, encoding=encoding)
+
+ return result
+
+
+# Override functions with fast versions if available
+if _psycopg:
+ connect = _psycopg.connect
+ execute = _psycopg.execute
+ send = _psycopg.send
+ fetch_many = _psycopg.fetch_many
+ fetch = _psycopg.fetch
+ pipeline_communicate = _psycopg.pipeline_communicate
+
+else:
+ connect = _connect
+ execute = _execute
+ send = _send
+ fetch_many = _fetch_many
+ fetch = _fetch
+ pipeline_communicate = _pipeline_communicate
diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py
new file mode 100644
index 0000000..792a9c8
--- /dev/null
+++ b/psycopg/psycopg/postgres.py
@@ -0,0 +1,125 @@
+"""
+Types configuration specific to PostgreSQL.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
+from .abc import AdaptContext
+from ._adapters_map import AdaptersMap
+
+# Global objects with PostgreSQL builtins and globally registered user types.
+types = TypesRegistry()
+
+# Global adapter maps with PostgreSQL types configuration
+adapters = AdaptersMap(types=types)
+
+# Use tools/update_oids.py to update this data.
+for t in [
+ TypeInfo('"char"', 18, 1002),
+ # autogenerated: start
+ # Generated from PostgreSQL 15.0
+ TypeInfo("aclitem", 1033, 1034),
+ TypeInfo("bit", 1560, 1561),
+ TypeInfo("bool", 16, 1000, regtype="boolean"),
+ TypeInfo("box", 603, 1020, delimiter=";"),
+ TypeInfo("bpchar", 1042, 1014, regtype="character"),
+ TypeInfo("bytea", 17, 1001),
+ TypeInfo("cid", 29, 1012),
+ TypeInfo("cidr", 650, 651),
+ TypeInfo("circle", 718, 719),
+ TypeInfo("date", 1082, 1182),
+ TypeInfo("float4", 700, 1021, regtype="real"),
+ TypeInfo("float8", 701, 1022, regtype="double precision"),
+ TypeInfo("gtsvector", 3642, 3644),
+ TypeInfo("inet", 869, 1041),
+ TypeInfo("int2", 21, 1005, regtype="smallint"),
+ TypeInfo("int2vector", 22, 1006),
+ TypeInfo("int4", 23, 1007, regtype="integer"),
+ TypeInfo("int8", 20, 1016, regtype="bigint"),
+ TypeInfo("interval", 1186, 1187),
+ TypeInfo("json", 114, 199),
+ TypeInfo("jsonb", 3802, 3807),
+ TypeInfo("jsonpath", 4072, 4073),
+ TypeInfo("line", 628, 629),
+ TypeInfo("lseg", 601, 1018),
+ TypeInfo("macaddr", 829, 1040),
+ TypeInfo("macaddr8", 774, 775),
+ TypeInfo("money", 790, 791),
+ TypeInfo("name", 19, 1003),
+ TypeInfo("numeric", 1700, 1231),
+ TypeInfo("oid", 26, 1028),
+ TypeInfo("oidvector", 30, 1013),
+ TypeInfo("path", 602, 1019),
+ TypeInfo("pg_lsn", 3220, 3221),
+ TypeInfo("point", 600, 1017),
+ TypeInfo("polygon", 604, 1027),
+ TypeInfo("record", 2249, 2287),
+ TypeInfo("refcursor", 1790, 2201),
+ TypeInfo("regclass", 2205, 2210),
+ TypeInfo("regcollation", 4191, 4192),
+ TypeInfo("regconfig", 3734, 3735),
+ TypeInfo("regdictionary", 3769, 3770),
+ TypeInfo("regnamespace", 4089, 4090),
+ TypeInfo("regoper", 2203, 2208),
+ TypeInfo("regoperator", 2204, 2209),
+ TypeInfo("regproc", 24, 1008),
+ TypeInfo("regprocedure", 2202, 2207),
+ TypeInfo("regrole", 4096, 4097),
+ TypeInfo("regtype", 2206, 2211),
+ TypeInfo("text", 25, 1009),
+ TypeInfo("tid", 27, 1010),
+ TypeInfo("time", 1083, 1183, regtype="time without time zone"),
+ TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
+ TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
+ TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
+ TypeInfo("tsquery", 3615, 3645),
+ TypeInfo("tsvector", 3614, 3643),
+ TypeInfo("txid_snapshot", 2970, 2949),
+ TypeInfo("uuid", 2950, 2951),
+ TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
+ TypeInfo("varchar", 1043, 1015, regtype="character varying"),
+ TypeInfo("xid", 28, 1011),
+ TypeInfo("xid8", 5069, 271),
+ TypeInfo("xml", 142, 143),
+ RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
+ RangeInfo("int4range", 3904, 3905, subtype_oid=23),
+ RangeInfo("int8range", 3926, 3927, subtype_oid=20),
+ RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
+ RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
+ RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
+ MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
+ MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
+ MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
+ MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
+ MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
+ MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
+ # autogenerated: end
+]:
+ types.add(t)
+
+
+# A few oids used a bit everywhere
+INVALID_OID = 0
+TEXT_OID = types["text"].oid
+TEXT_ARRAY_OID = types["text"].array_oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+
+ from .types import array, bool, composite, datetime, enum, json, multirange
+ from .types import net, none, numeric, range, string, uuid
+
+ array.register_default_adapters(context)
+ bool.register_default_adapters(context)
+ composite.register_default_adapters(context)
+ datetime.register_default_adapters(context)
+ enum.register_default_adapters(context)
+ json.register_default_adapters(context)
+ multirange.register_default_adapters(context)
+ net.register_default_adapters(context)
+ none.register_default_adapters(context)
+ numeric.register_default_adapters(context)
+ range.register_default_adapters(context)
+ string.register_default_adapters(context)
+ uuid.register_default_adapters(context)
diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py
new file mode 100644
index 0000000..d5180b1
--- /dev/null
+++ b/psycopg/psycopg/pq/__init__.py
@@ -0,0 +1,133 @@
+"""
+psycopg libpq wrapper
+
+This package exposes the libpq functionalities as Python objects and functions.
+
+The real implementation (the binding to the C library) is
+implementation-dependent, but all the implementations share the same interface.
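+
+The implementation to load can be chosen with the ``PSYCOPG_IMPL`` environment
+variable; a minimal sketch (the variable must be set before psycopg is
+imported)::
+
+    import os
+    os.environ["PSYCOPG_IMPL"] = "python"
+
+    import psycopg
+    print(psycopg.pq.__impl__)  # -> "python"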
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import logging
+from typing import Callable, List, Type
+
+from . import abc
+from .misc import ConninfoOption, PGnotify, PGresAttDesc
+from .misc import error_message
+from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
+from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
+
+logger = logging.getLogger(__name__)
+
+__impl__: str
+"""The currently loaded implementation of the `!psycopg.pq` package.
+
+Possible values include ``python``, ``c``, ``binary``.
+"""
+
+__build_version__: int
+"""The libpq version the C package was built with.
+
+A number in the same format as `~psycopg.ConnectionInfo.server_version` (e.g.
+140002 for version 14.2), representing the libpq used to build the speedup
+module (``c``, ``binary``), if available.
+
+Certain features might not be available if the built version is too old.
+"""
+
+version: Callable[[], int]
+PGconn: Type[abc.PGconn]
+PGresult: Type[abc.PGresult]
+Conninfo: Type[abc.Conninfo]
+Escaping: Type[abc.Escaping]
+PGcancel: Type[abc.PGcancel]
+
+
+def import_from_libpq() -> None:
+ """
+ Import the pq objects implementation from the best libpq wrapper available.
+
+ If a specific implementation is requested, try to import only that one;
+ otherwise try to import the best implementation available.
+ """
+ # import these names into the module on success as side effect
+ global __impl__, version, __build_version__
+ global PGconn, PGresult, Conninfo, Escaping, PGcancel
+
+ impl = os.environ.get("PSYCOPG_IMPL", "").lower()
+ module = None
+ attempts: List[str] = []
+
+ def handle_error(name: str, e: Exception) -> None:
+ if not impl:
+ msg = f"couldn't import psycopg '{name}' implementation: {e}"
+ logger.debug(msg)
+ attempts.append(msg)
+ else:
+ msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
+ raise ImportError(msg) from e
+
+ # The best implementation: fast but requires the system libpq installed
+ if not impl or impl == "c":
+ try:
+ from psycopg_c import pq as module # type: ignore
+ except Exception as e:
+ handle_error("c", e)
+
+ # Second best implementation: fast and stand-alone
+ if not module and (not impl or impl == "binary"):
+ try:
+ from psycopg_binary import pq as module # type: ignore
+ except Exception as e:
+ handle_error("binary", e)
+
+ # Pure Python implementation, slow and requires the system libpq installed.
+ if not module and (not impl or impl == "python"):
+ try:
+ from . import pq_ctypes as module # type: ignore[no-redef]
+ except Exception as e:
+ handle_error("python", e)
+
+ if module:
+ __impl__ = module.__impl__
+ version = module.version
+ PGconn = module.PGconn
+ PGresult = module.PGresult
+ Conninfo = module.Conninfo
+ Escaping = module.Escaping
+ PGcancel = module.PGcancel
+ __build_version__ = module.__build_version__
+ elif impl:
+ raise ImportError(f"requested psycopg implementation '{impl}' unknown")
+ else:
+ sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
+ raise ImportError(
+ f"""\
+no pq wrapper available.
+Attempts made:
+{sattempts}"""
+ )
+
+
+import_from_libpq()
+
+__all__ = (
+ "ConnStatus",
+ "PipelineStatus",
+ "PollingStatus",
+ "TransactionStatus",
+ "ExecStatus",
+ "Ping",
+ "DiagnosticField",
+ "Format",
+ "Trace",
+ "PGconn",
+ "PGnotify",
+ "Conninfo",
+ "PGresAttDesc",
+ "error_message",
+ "ConninfoOption",
+ "version",
+)
diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py
new file mode 100644
index 0000000..f35d09f
--- /dev/null
+++ b/psycopg/psycopg/pq/_debug.py
@@ -0,0 +1,106 @@
+"""
+libpq debugging tools
+
+These functionalities are exposed here for convenience, but are not part of
+the public interface and are subject to change at any moment.
+
+Suggested usage::
+
+ import logging
+ import psycopg
+ from psycopg import pq
+ from psycopg.pq._debug import PGconnDebug
+
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
+ logger = logging.getLogger("psycopg.debug")
+ logger.setLevel(logging.INFO)
+
+ assert pq.__impl__ == "python"
+ pq.PGconn = PGconnDebug
+
+ with psycopg.connect("") as conn:
+ conn.pgconn.trace(2)
+ conn.pgconn.set_trace_flags(
+ pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
+ ...
+
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import inspect
+import logging
+from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING
+from functools import wraps
+
+from . import PGconn
+from .misc import connection_summary
+
+if TYPE_CHECKING:
+ from . import abc
+
+Func = TypeVar("Func", bound=Callable[..., Any])
+
+logger = logging.getLogger("psycopg.debug")
+
+
+class PGconnDebug:
+ """Wrapper for a PQconn logging all its access."""
+
+ _Self = TypeVar("_Self", bound="PGconnDebug")
+ _pgconn: "abc.PGconn"
+
+ def __init__(self, pgconn: "abc.PGconn"):
+ super().__setattr__("_pgconn", pgconn)
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self._pgconn)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ def __getattr__(self, attr: str) -> Any:
+ value = getattr(self._pgconn, attr)
+ if callable(value):
+ return debugging(value)
+ else:
+ logger.info("PGconn.%s -> %s", attr, value)
+ return value
+
+ def __setattr__(self, attr: str, value: Any) -> None:
+ setattr(self._pgconn, attr, value)
+ logger.info("PGconn.%s <- %s", attr, value)
+
+ @classmethod
+ def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect)(conninfo))
+
+ @classmethod
+ def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
+ return cls(debugging(PGconn.connect_start)(conninfo))
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ return debugging(PGconn.ping)(conninfo)
+
+
+def debugging(f: Func) -> Func:
+ """Wrap a function in order to log its arguments and return value on call."""
+
+ @wraps(f)
+ def debugging_(*args: Any, **kwargs: Any) -> Any:
+ reprs = []
+ for arg in args:
+ reprs.append(f"{arg!r}")
+ for (k, v) in kwargs.items():
+ reprs.append(f"{k}={v!r}")
+
+ logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs))
+ rv = f(*args, **kwargs)
+ # Display the return value only if the function is declared to return
+ # something other than None.
+ ra = inspect.signature(f).return_annotation
+ if ra is not None or rv is not None:
+ logger.info(" <- %r", rv)
+ return rv
+
+ return debugging_ # type: ignore
diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py
new file mode 100644
index 0000000..e0d4018
--- /dev/null
+++ b/psycopg/psycopg/pq/_enums.py
@@ -0,0 +1,249 @@
+"""
+libpq enum definitions for psycopg
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from enum import IntEnum, IntFlag, auto
+
+
+class ConnStatus(IntEnum):
+ """
+ Current status of the connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """The connection is in a working state."""
+ BAD = auto()
+ """The connection is closed."""
+
+ STARTED = auto()
+ MADE = auto()
+ AWAITING_RESPONSE = auto()
+ AUTH_OK = auto()
+ SETENV = auto()
+ SSL_STARTUP = auto()
+ NEEDED = auto()
+ CHECK_WRITABLE = auto()
+ CONSUME = auto()
+ GSS_STARTUP = auto()
+ CHECK_TARGET = auto()
+ CHECK_STANDBY = auto()
+
+
+class PollingStatus(IntEnum):
+ """
+ The status of the socket during a connection.
+
+ If ``READING`` or ``WRITING``, you may select on the connection socket before polling again.
+ """
+
+ __module__ = "psycopg.pq"
+
+ FAILED = 0
+ """Connection attempt failed."""
+ READING = auto()
+ """Will have to wait before reading new data."""
+ WRITING = auto()
+ """Will have to wait before writing new data."""
+ OK = auto()
+ """Connection completed."""
+
+ ACTIVE = auto()
+
+
+class ExecStatus(IntEnum):
+ """
+ The status of a command.
+ """
+
+ __module__ = "psycopg.pq"
+
+ EMPTY_QUERY = 0
+ """The string sent to the server was empty."""
+
+ COMMAND_OK = auto()
+ """Successful completion of a command returning no data."""
+
+ TUPLES_OK = auto()
+ """
+ Successful completion of a command returning data (such as a SELECT or SHOW).
+ """
+
+ COPY_OUT = auto()
+ """Copy Out (from server) data transfer started."""
+
+ COPY_IN = auto()
+ """Copy In (to server) data transfer started."""
+
+ BAD_RESPONSE = auto()
+ """The server's response was not understood."""
+
+ NONFATAL_ERROR = auto()
+ """A nonfatal error (a notice or warning) occurred."""
+
+ FATAL_ERROR = auto()
+ """A fatal error occurred."""
+
+ COPY_BOTH = auto()
+ """
+ Copy In/Out (to and from server) data transfer started.
+
+ This feature is currently used only for streaming replication, so this
+ status should not occur in ordinary applications.
+ """
+
+ SINGLE_TUPLE = auto()
+ """
+ The PGresult contains a single result tuple from the current command.
+
+ This status occurs only when single-row mode has been selected for the
+ query.
+ """
+
+ PIPELINE_SYNC = auto()
+ """
+ The PGresult represents a synchronization point in pipeline mode,
+ requested by PQpipelineSync.
+
+ This status occurs only when pipeline mode has been selected.
+ """
+
+ PIPELINE_ABORTED = auto()
+ """
+ The PGresult represents a pipeline that has received an error from the server.
+
+ PQgetResult must be called repeatedly, and each time it will return this
+ status code until the end of the current pipeline, at which point it will
+ return PGRES_PIPELINE_SYNC and normal processing can resume.
+ """
+
+
+class TransactionStatus(IntEnum):
+ """
+ The transaction status of a connection.
+ """
+
+ __module__ = "psycopg.pq"
+
+ IDLE = 0
+ """Connection ready, no transaction active."""
+
+ ACTIVE = auto()
+ """A command is in progress."""
+
+ INTRANS = auto()
+ """Connection idle in an open transaction."""
+
+ INERROR = auto()
+ """An error happened in the current transaction."""
+
+ UNKNOWN = auto()
+ """Unknown connection state, broken connection."""
+
+
+class Ping(IntEnum):
+ """Response from a ping attempt."""
+
+ __module__ = "psycopg.pq"
+
+ OK = 0
+ """
+ The server is running and appears to be accepting connections.
+ """
+
+ REJECT = auto()
+ """
+ The server is running but is in a state that disallows connections.
+ """
+
+ NO_RESPONSE = auto()
+ """
+ The server could not be contacted.
+ """
+
+ NO_ATTEMPT = auto()
+ """
+ No attempt was made to contact the server.
+ """
+
+
+class PipelineStatus(IntEnum):
+ """Pipeline mode status of the libpq connection."""
+
+ __module__ = "psycopg.pq"
+
+ OFF = 0
+ """
+ The libpq connection is *not* in pipeline mode.
+ """
+ ON = auto()
+ """
+ The libpq connection is in pipeline mode.
+ """
+ ABORTED = auto()
+ """
+ The libpq connection is in pipeline mode and an error occurred while
+ processing the current pipeline. The aborted flag is cleared when
+ PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
+ """
+
+
+class DiagnosticField(IntEnum):
+ """
+ Fields in an error report.
+ """
+
+ __module__ = "psycopg.pq"
+
+ # from postgres_ext.h
+ SEVERITY = ord("S")
+ SEVERITY_NONLOCALIZED = ord("V")
+ SQLSTATE = ord("C")
+ MESSAGE_PRIMARY = ord("M")
+ MESSAGE_DETAIL = ord("D")
+ MESSAGE_HINT = ord("H")
+ STATEMENT_POSITION = ord("P")
+ INTERNAL_POSITION = ord("p")
+ INTERNAL_QUERY = ord("q")
+ CONTEXT = ord("W")
+ SCHEMA_NAME = ord("s")
+ TABLE_NAME = ord("t")
+ COLUMN_NAME = ord("c")
+ DATATYPE_NAME = ord("d")
+ CONSTRAINT_NAME = ord("n")
+ SOURCE_FILE = ord("F")
+ SOURCE_LINE = ord("L")
+ SOURCE_FUNCTION = ord("R")
+
+
+class Format(IntEnum):
+ """
+ Enum representing the format of a query argument or return value.
+
+ These values are only the ones managed by the libpq. `~psycopg` may also
+ support automatically-chosen values: see `psycopg.adapt.PyFormat`.
+ """
+
+ __module__ = "psycopg.pq"
+
+ TEXT = 0
+ """Text parameter."""
+ BINARY = 1
+ """Binary parameter."""
+
+
+class Trace(IntFlag):
+ """
+ Enum to control tracing of the client/server communication.
+ """
+
+ __module__ = "psycopg.pq"
+
+ SUPPRESS_TIMESTAMPS = 1
+ """Do not include timestamps in messages."""
+
+ REGRESS_MODE = 2
+ """Redact some fields, e.g. OIDs, from messages."""
diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py
new file mode 100644
index 0000000..9ca1d12
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.py
@@ -0,0 +1,804 @@
+"""
+libpq access using ctypes
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import ctypes
+import ctypes.util
+from ctypes import Structure, CFUNCTYPE, POINTER
+from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
+from typing import List, Optional, Tuple
+
+from .misc import find_libpq_full_path
+from ..errors import NotSupportedError
+
+libname = find_libpq_full_path()
+if not libname:
+ raise ImportError("libpq library not found")
+
+pq = ctypes.cdll.LoadLibrary(libname)
+
+
+class FILE(Structure):
+ pass
+
+
+FILE_ptr = POINTER(FILE)
+
+if sys.platform == "linux":
+ libcname = ctypes.util.find_library("c")
+ assert libcname
+ libc = ctypes.cdll.LoadLibrary(libcname)
+
+ fdopen = libc.fdopen
+ fdopen.argtypes = (c_int, c_char_p)
+ fdopen.restype = FILE_ptr
+
+
+# Get the libpq version to define what functions are available.
+
+PQlibVersion = pq.PQlibVersion
+PQlibVersion.argtypes = []
+PQlibVersion.restype = c_int
+
+libpq_version = PQlibVersion()
+
+
+# libpq data types
+
+
+Oid = c_uint
+
+
+class PGconn_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresult_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PQconninfoOption_struct(Structure):
+ _fields_ = [
+ ("keyword", c_char_p),
+ ("envvar", c_char_p),
+ ("compiled", c_char_p),
+ ("val", c_char_p),
+ ("label", c_char_p),
+ ("dispchar", c_char_p),
+ ("dispsize", c_int),
+ ]
+
+
+class PGnotify_struct(Structure):
+ _fields_ = [
+ ("relname", c_char_p),
+ ("be_pid", c_int),
+ ("extra", c_char_p),
+ ]
+
+
+class PGcancel_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
+class PGresAttDesc_struct(Structure):
+ _fields_ = [
+ ("name", c_char_p),
+ ("tableid", Oid),
+ ("columnid", c_int),
+ ("format", c_int),
+ ("typid", Oid),
+ ("typlen", c_int),
+ ("atttypmod", c_int),
+ ]
+
+
+PGconn_ptr = POINTER(PGconn_struct)
+PGresult_ptr = POINTER(PGresult_struct)
+PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
+PGnotify_ptr = POINTER(PGnotify_struct)
+PGcancel_ptr = POINTER(PGcancel_struct)
+PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
+
+
+# Function definitions as explained in PostgreSQL 12 documentation
+
+# 33.1. Database Connection Control Functions
+
+# PQconnectdbParams: doesn't seem useful, won't wrap for now
+
+PQconnectdb = pq.PQconnectdb
+PQconnectdb.argtypes = [c_char_p]
+PQconnectdb.restype = PGconn_ptr
+
+# PQsetdbLogin: not useful
+# PQsetdb: not useful
+
+# PQconnectStartParams: not useful
+
+PQconnectStart = pq.PQconnectStart
+PQconnectStart.argtypes = [c_char_p]
+PQconnectStart.restype = PGconn_ptr
+
+PQconnectPoll = pq.PQconnectPoll
+PQconnectPoll.argtypes = [PGconn_ptr]
+PQconnectPoll.restype = c_int
+
+PQconndefaults = pq.PQconndefaults
+PQconndefaults.argtypes = []
+PQconndefaults.restype = PQconninfoOption_ptr
+
+PQconninfoFree = pq.PQconninfoFree
+PQconninfoFree.argtypes = [PQconninfoOption_ptr]
+PQconninfoFree.restype = None
+
+PQconninfo = pq.PQconninfo
+PQconninfo.argtypes = [PGconn_ptr]
+PQconninfo.restype = PQconninfoOption_ptr
+
+PQconninfoParse = pq.PQconninfoParse
+PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
+PQconninfoParse.restype = PQconninfoOption_ptr
+
+PQfinish = pq.PQfinish
+PQfinish.argtypes = [PGconn_ptr]
+PQfinish.restype = None
+
+PQreset = pq.PQreset
+PQreset.argtypes = [PGconn_ptr]
+PQreset.restype = None
+
+PQresetStart = pq.PQresetStart
+PQresetStart.argtypes = [PGconn_ptr]
+PQresetStart.restype = c_int
+
+PQresetPoll = pq.PQresetPoll
+PQresetPoll.argtypes = [PGconn_ptr]
+PQresetPoll.restype = c_int
+
+PQping = pq.PQping
+PQping.argtypes = [c_char_p]
+PQping.restype = c_int
+
+
+# 33.2. Connection Status Functions
+
+PQdb = pq.PQdb
+PQdb.argtypes = [PGconn_ptr]
+PQdb.restype = c_char_p
+
+PQuser = pq.PQuser
+PQuser.argtypes = [PGconn_ptr]
+PQuser.restype = c_char_p
+
+PQpass = pq.PQpass
+PQpass.argtypes = [PGconn_ptr]
+PQpass.restype = c_char_p
+
+PQhost = pq.PQhost
+PQhost.argtypes = [PGconn_ptr]
+PQhost.restype = c_char_p
+
+_PQhostaddr = None
+
+if libpq_version >= 120000:
+ _PQhostaddr = pq.PQhostaddr
+ _PQhostaddr.argtypes = [PGconn_ptr]
+ _PQhostaddr.restype = c_char_p
+
+
+def PQhostaddr(pgconn: PGconn_struct) -> bytes:
+ if not _PQhostaddr:
+ raise NotSupportedError(
+ "PQhostaddr requires libpq from PostgreSQL 12,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQhostaddr(pgconn)
+
+
+PQport = pq.PQport
+PQport.argtypes = [PGconn_ptr]
+PQport.restype = c_char_p
+
+PQtty = pq.PQtty
+PQtty.argtypes = [PGconn_ptr]
+PQtty.restype = c_char_p
+
+PQoptions = pq.PQoptions
+PQoptions.argtypes = [PGconn_ptr]
+PQoptions.restype = c_char_p
+
+PQstatus = pq.PQstatus
+PQstatus.argtypes = [PGconn_ptr]
+PQstatus.restype = c_int
+
+PQtransactionStatus = pq.PQtransactionStatus
+PQtransactionStatus.argtypes = [PGconn_ptr]
+PQtransactionStatus.restype = c_int
+
+PQparameterStatus = pq.PQparameterStatus
+PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
+PQparameterStatus.restype = c_char_p
+
+PQprotocolVersion = pq.PQprotocolVersion
+PQprotocolVersion.argtypes = [PGconn_ptr]
+PQprotocolVersion.restype = c_int
+
+PQserverVersion = pq.PQserverVersion
+PQserverVersion.argtypes = [PGconn_ptr]
+PQserverVersion.restype = c_int
+
+PQerrorMessage = pq.PQerrorMessage
+PQerrorMessage.argtypes = [PGconn_ptr]
+PQerrorMessage.restype = c_char_p
+
+PQsocket = pq.PQsocket
+PQsocket.argtypes = [PGconn_ptr]
+PQsocket.restype = c_int
+
+PQbackendPID = pq.PQbackendPID
+PQbackendPID.argtypes = [PGconn_ptr]
+PQbackendPID.restype = c_int
+
+PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
+PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
+PQconnectionNeedsPassword.restype = c_int
+
+PQconnectionUsedPassword = pq.PQconnectionUsedPassword
+PQconnectionUsedPassword.argtypes = [PGconn_ptr]
+PQconnectionUsedPassword.restype = c_int
+
+PQsslInUse = pq.PQsslInUse
+PQsslInUse.argtypes = [PGconn_ptr]
+PQsslInUse.restype = c_int
+
+# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
+
+
+# 33.3. Command Execution Functions
+
+PQexec = pq.PQexec
+PQexec.argtypes = [PGconn_ptr, c_char_p]
+PQexec.restype = PGresult_ptr
+
+PQexecParams = pq.PQexecParams
+PQexecParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecParams.restype = PGresult_ptr
+
+PQprepare = pq.PQprepare
+PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQprepare.restype = PGresult_ptr
+
+PQexecPrepared = pq.PQexecPrepared
+PQexecPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQexecPrepared.restype = PGresult_ptr
+
+PQdescribePrepared = pq.PQdescribePrepared
+PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePrepared.restype = PGresult_ptr
+
+PQdescribePortal = pq.PQdescribePortal
+PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQdescribePortal.restype = PGresult_ptr
+
+PQresultStatus = pq.PQresultStatus
+PQresultStatus.argtypes = [PGresult_ptr]
+PQresultStatus.restype = c_int
+
+# PQresStatus: not needed, we have pretty enums
+
+PQresultErrorMessage = pq.PQresultErrorMessage
+PQresultErrorMessage.argtypes = [PGresult_ptr]
+PQresultErrorMessage.restype = c_char_p
+
+# TODO: PQresultVerboseErrorMessage
+
+PQresultErrorField = pq.PQresultErrorField
+PQresultErrorField.argtypes = [PGresult_ptr, c_int]
+PQresultErrorField.restype = c_char_p
+
+PQclear = pq.PQclear
+PQclear.argtypes = [PGresult_ptr]
+PQclear.restype = None
+
+
+# 33.3.2. Retrieving Query Result Information
+
+PQntuples = pq.PQntuples
+PQntuples.argtypes = [PGresult_ptr]
+PQntuples.restype = c_int
+
+PQnfields = pq.PQnfields
+PQnfields.argtypes = [PGresult_ptr]
+PQnfields.restype = c_int
+
+PQfname = pq.PQfname
+PQfname.argtypes = [PGresult_ptr, c_int]
+PQfname.restype = c_char_p
+
+# PQfnumber: useless and hard to use
+
+PQftable = pq.PQftable
+PQftable.argtypes = [PGresult_ptr, c_int]
+PQftable.restype = Oid
+
+PQftablecol = pq.PQftablecol
+PQftablecol.argtypes = [PGresult_ptr, c_int]
+PQftablecol.restype = c_int
+
+PQfformat = pq.PQfformat
+PQfformat.argtypes = [PGresult_ptr, c_int]
+PQfformat.restype = c_int
+
+PQftype = pq.PQftype
+PQftype.argtypes = [PGresult_ptr, c_int]
+PQftype.restype = Oid
+
+PQfmod = pq.PQfmod
+PQfmod.argtypes = [PGresult_ptr, c_int]
+PQfmod.restype = c_int
+
+PQfsize = pq.PQfsize
+PQfsize.argtypes = [PGresult_ptr, c_int]
+PQfsize.restype = c_int
+
+PQbinaryTuples = pq.PQbinaryTuples
+PQbinaryTuples.argtypes = [PGresult_ptr]
+PQbinaryTuples.restype = c_int
+
+PQgetvalue = pq.PQgetvalue
+PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
+
+PQgetisnull = pq.PQgetisnull
+PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetisnull.restype = c_int
+
+PQgetlength = pq.PQgetlength
+PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
+PQgetlength.restype = c_int
+
+PQnparams = pq.PQnparams
+PQnparams.argtypes = [PGresult_ptr]
+PQnparams.restype = c_int
+
+PQparamtype = pq.PQparamtype
+PQparamtype.argtypes = [PGresult_ptr, c_int]
+PQparamtype.restype = Oid
+
+# PQprint: pretty useless
+
+# 33.3.3. Retrieving Other Result Information
+
+PQcmdStatus = pq.PQcmdStatus
+PQcmdStatus.argtypes = [PGresult_ptr]
+PQcmdStatus.restype = c_char_p
+
+PQcmdTuples = pq.PQcmdTuples
+PQcmdTuples.argtypes = [PGresult_ptr]
+PQcmdTuples.restype = c_char_p
+
+PQoidValue = pq.PQoidValue
+PQoidValue.argtypes = [PGresult_ptr]
+PQoidValue.restype = Oid
+
+
+# 33.3.4. Escaping Strings for Inclusion in SQL Commands
+
+PQescapeLiteral = pq.PQescapeLiteral
+PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeLiteral.restype = POINTER(c_char)
+
+PQescapeIdentifier = pq.PQescapeIdentifier
+PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeIdentifier.restype = POINTER(c_char)
+
+PQescapeStringConn = pq.PQescapeStringConn
+# TODO: raises "wrong type" error
+# PQescapeStringConn.argtypes = [
+# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
+# ]
+PQescapeStringConn.restype = c_size_t
+
+PQescapeString = pq.PQescapeString
+# TODO: raises "wrong type" error
+# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
+PQescapeString.restype = c_size_t
+
+PQescapeByteaConn = pq.PQescapeByteaConn
+PQescapeByteaConn.argtypes = [
+ PGconn_ptr,
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeByteaConn.restype = POINTER(c_ubyte)
+
+PQescapeBytea = pq.PQescapeBytea
+PQescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ c_size_t,
+ POINTER(c_size_t),
+]
+PQescapeBytea.restype = POINTER(c_ubyte)
+
+
+PQunescapeBytea = pq.PQunescapeBytea
+PQunescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ POINTER(c_size_t),
+]
+PQunescapeBytea.restype = POINTER(c_ubyte)
+
+
+# 33.4. Asynchronous Command Processing
+
+PQsendQuery = pq.PQsendQuery
+PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
+PQsendQuery.restype = c_int
+
+PQsendQueryParams = pq.PQsendQueryParams
+PQsendQueryParams.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(Oid),
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryParams.restype = c_int
+
+PQsendPrepare = pq.PQsendPrepare
+PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
+PQsendPrepare.restype = c_int
+
+PQsendQueryPrepared = pq.PQsendQueryPrepared
+PQsendQueryPrepared.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_int,
+ POINTER(c_char_p),
+ POINTER(c_int),
+ POINTER(c_int),
+ c_int,
+]
+PQsendQueryPrepared.restype = c_int
+
+PQsendDescribePrepared = pq.PQsendDescribePrepared
+PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePrepared.restype = c_int
+
+PQsendDescribePortal = pq.PQsendDescribePortal
+PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
+PQsendDescribePortal.restype = c_int
+
+PQgetResult = pq.PQgetResult
+PQgetResult.argtypes = [PGconn_ptr]
+PQgetResult.restype = PGresult_ptr
+
+PQconsumeInput = pq.PQconsumeInput
+PQconsumeInput.argtypes = [PGconn_ptr]
+PQconsumeInput.restype = c_int
+
+PQisBusy = pq.PQisBusy
+PQisBusy.argtypes = [PGconn_ptr]
+PQisBusy.restype = c_int
+
+PQsetnonblocking = pq.PQsetnonblocking
+PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
+PQsetnonblocking.restype = c_int
+
+PQisnonblocking = pq.PQisnonblocking
+PQisnonblocking.argtypes = [PGconn_ptr]
+PQisnonblocking.restype = c_int
+
+PQflush = pq.PQflush
+PQflush.argtypes = [PGconn_ptr]
+PQflush.restype = c_int
+
+
+# 33.5. Retrieving Query Results Row-by-Row
+PQsetSingleRowMode = pq.PQsetSingleRowMode
+PQsetSingleRowMode.argtypes = [PGconn_ptr]
+PQsetSingleRowMode.restype = c_int
+
+
+# 33.6. Canceling Queries in Progress
+
+PQgetCancel = pq.PQgetCancel
+PQgetCancel.argtypes = [PGconn_ptr]
+PQgetCancel.restype = PGcancel_ptr
+
+PQfreeCancel = pq.PQfreeCancel
+PQfreeCancel.argtypes = [PGcancel_ptr]
+PQfreeCancel.restype = None
+
+PQcancel = pq.PQcancel
+# TODO: raises "wrong type" error
+# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
+PQcancel.restype = c_int
+
+
+# 33.8. Asynchronous Notification
+
+PQnotifies = pq.PQnotifies
+PQnotifies.argtypes = [PGconn_ptr]
+PQnotifies.restype = PGnotify_ptr
+
+
+# 33.9. Functions Associated with the COPY Command
+
+PQputCopyData = pq.PQputCopyData
+PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
+PQputCopyData.restype = c_int
+
+PQputCopyEnd = pq.PQputCopyEnd
+PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
+PQputCopyEnd.restype = c_int
+
+PQgetCopyData = pq.PQgetCopyData
+PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
+PQgetCopyData.restype = c_int
+
+
+# 33.10. Control Functions
+
+PQtrace = pq.PQtrace
+PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
+PQtrace.restype = None
+
+_PQsetTraceFlags = None
+
+if libpq_version >= 140000:
+ _PQsetTraceFlags = pq.PQsetTraceFlags
+ _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
+ _PQsetTraceFlags.restype = None
+
+
+def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
+ if not _PQsetTraceFlags:
+ raise NotSupportedError(
+ "PQsetTraceFlags requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+
+ _PQsetTraceFlags(pgconn, flags)
+
+
+PQuntrace = pq.PQuntrace
+PQuntrace.argtypes = [PGconn_ptr]
+PQuntrace.restype = None
+
+# 33.11. Miscellaneous Functions
+
+PQfreemem = pq.PQfreemem
+PQfreemem.argtypes = [c_void_p]
+PQfreemem.restype = None
+
+if libpq_version >= 100000:
+ _PQencryptPasswordConn = pq.PQencryptPasswordConn
+ _PQencryptPasswordConn.argtypes = [
+ PGconn_ptr,
+ c_char_p,
+ c_char_p,
+ c_char_p,
+ ]
+ _PQencryptPasswordConn.restype = POINTER(c_char)
+
+
+def PQencryptPasswordConn(
+ pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
+) -> Optional[bytes]:
+ if not _PQencryptPasswordConn:
+ raise NotSupportedError(
+ "PQencryptPasswordConn requires libpq from PostgreSQL 10,"
+ f" {libpq_version} available instead"
+ )
+
+ return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
+
+
+PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
+PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
+PQmakeEmptyPGresult.restype = PGresult_ptr
+
+PQsetResultAttrs = pq.PQsetResultAttrs
+PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
+PQsetResultAttrs.restype = c_int
+
+
+# 33.12. Notice Processing
+
+PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
+
+PQsetNoticeReceiver = pq.PQsetNoticeReceiver
+PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
+PQsetNoticeReceiver.restype = PQnoticeReceiver
+
+# 34.5 Pipeline Mode
+
+_PQpipelineStatus = None
+_PQenterPipelineMode = None
+_PQexitPipelineMode = None
+_PQpipelineSync = None
+_PQsendFlushRequest = None
+
+if libpq_version >= 140000:
+ _PQpipelineStatus = pq.PQpipelineStatus
+ _PQpipelineStatus.argtypes = [PGconn_ptr]
+ _PQpipelineStatus.restype = c_int
+
+ _PQenterPipelineMode = pq.PQenterPipelineMode
+ _PQenterPipelineMode.argtypes = [PGconn_ptr]
+ _PQenterPipelineMode.restype = c_int
+
+ _PQexitPipelineMode = pq.PQexitPipelineMode
+ _PQexitPipelineMode.argtypes = [PGconn_ptr]
+ _PQexitPipelineMode.restype = c_int
+
+ _PQpipelineSync = pq.PQpipelineSync
+ _PQpipelineSync.argtypes = [PGconn_ptr]
+ _PQpipelineSync.restype = c_int
+
+ _PQsendFlushRequest = pq.PQsendFlushRequest
+ _PQsendFlushRequest.argtypes = [PGconn_ptr]
+ _PQsendFlushRequest.restype = c_int
+
+
+def PQpipelineStatus(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineStatus:
+ raise NotSupportedError(
+ "PQpipelineStatus requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineStatus(pgconn)
+
+
+def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQenterPipelineMode:
+ raise NotSupportedError(
+ "PQenterPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQenterPipelineMode(pgconn)
+
+
+def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
+ if not _PQexitPipelineMode:
+ raise NotSupportedError(
+ "PQexitPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQexitPipelineMode(pgconn)
+
+
+def PQpipelineSync(pgconn: PGconn_struct) -> int:
+ if not _PQpipelineSync:
+ raise NotSupportedError(
+ "PQpipelineSync requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQpipelineSync(pgconn)
+
+
+def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
+ if not _PQsendFlushRequest:
+ raise NotSupportedError(
+ "PQsendFlushRequest requires libpq from PostgreSQL 14,"
+ f" {libpq_version} available instead"
+ )
+ return _PQsendFlushRequest(pgconn)
+
+
+# 33.18. SSL Support
+
+PQinitOpenSSL = pq.PQinitOpenSSL
+PQinitOpenSSL.argtypes = [c_int, c_int]
+PQinitOpenSSL.restype = None
+
+
+def generate_stub() -> None:
+ import re
+ from ctypes import _CFuncPtr # type: ignore
+
+ def type2str(fname, narg, t):
+ if t is None:
+ return "None"
+ elif t is c_void_p:
+ return "Any"
+ elif t is c_int or t is c_uint or t is c_size_t:
+ return "int"
+ elif t is c_char_p or t.__name__ == "LP_c_char":
+ if narg is not None:
+ return "bytes"
+ else:
+ return "Optional[bytes]"
+
+ elif t.__name__ in (
+ "LP_PGconn_struct",
+ "LP_PGresult_struct",
+ "LP_PGcancel_struct",
+ ):
+ if narg is not None:
+ return f"Optional[{t.__name__[3:]}]"
+ else:
+ return t.__name__[3:]
+
+ elif t.__name__ in ("LP_PQconninfoOption_struct",):
+ return f"Sequence[{t.__name__[3:]}]"
+
+ elif t.__name__ in (
+ "LP_c_ubyte",
+ "LP_c_char_p",
+ "LP_c_int",
+ "LP_c_uint",
+ "LP_c_ulong",
+ "LP_FILE",
+ ):
+ return f"_Pointer[{t.__name__[3:]}]"
+
+ else:
+ assert False, f"can't deal with {t} in {fname}"
+
+ fn = __file__ + "i"
+ with open(fn) as f:
+ lines = f.read().splitlines()
+
+ istart, iend = (
+ i
+ for i, line in enumerate(lines)
+ if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
+ )
+
+ known = {
+ line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
+ }
+
+ signatures = []
+
+ for name, obj in globals().items():
+ if name in known:
+ continue
+ if not isinstance(obj, _CFuncPtr):
+ continue
+
+ params = []
+ for i, t in enumerate(obj.argtypes):
+ params.append(f"arg{i + 1}: {type2str(name, i, t)}")
+
+ resname = type2str(name, None, obj.restype)
+
+ signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
+
+ lines[istart + 1 : iend] = signatures
+
+ with open(fn, "w") as f:
+ f.write("\n".join(lines))
+ f.write("\n")
+
+
+if __name__ == "__main__":
+ generate_stub()
diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi
new file mode 100644
index 0000000..5d2ee3f
--- /dev/null
+++ b/psycopg/psycopg/pq/_pq_ctypes.pyi
@@ -0,0 +1,216 @@
+"""
+types stub for ctypes functions
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Optional, Sequence
+from ctypes import Array, pointer, _Pointer
+from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
+
+class FILE: ...
+
+def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var]
+
+Oid = c_uint
+
+class PGconn_struct: ...
+class PGresult_struct: ...
+class PGcancel_struct: ...
+
+class PQconninfoOption_struct:
+ keyword: bytes
+ envvar: bytes
+ compiled: bytes
+ val: bytes
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+class PGnotify_struct:
+ be_pid: int
+ relname: bytes
+ extra: bytes
+
+class PGresAttDesc_struct:
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQexecPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> PGresult_struct: ...
+def PQprepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> PGresult_struct: ...
+def PQgetvalue(
+ arg1: Optional[PGresult_struct], arg2: int, arg3: int
+) -> _Pointer[c_char]: ...
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQescapeStringConn(
+ arg1: Optional[PGconn_struct],
+ arg2: c_char_p,
+ arg3: bytes,
+ arg4: int,
+ arg5: _Pointer[c_int],
+) -> int: ...
+def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ...
+def PQsendPrepare(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: int,
+ arg5: Optional[Array[c_uint]],
+) -> int: ...
+def PQsendQueryPrepared(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: int,
+ arg4: Optional[Array[c_char_p]],
+ arg5: Optional[Array[c_int]],
+ arg6: Optional[Array[c_int]],
+ arg7: int,
+) -> int: ...
+def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
+def PQsetNoticeReceiver(
+ arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
+) -> Callable[[Any], PGresult_struct]: ...
+
+# TODO: Ignoring type as getting an error on mypy/ctypes:
+# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be
+# a subtype of "ctypes._CData"
+def PQnotifies(
+ arg1: Optional[PGconn_struct],
+) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore
+def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
+
+# Arg 2 is a _Pointer, reported as _CArgObject by mypy
+def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
+def PQsetResultAttrs(
+ arg1: Optional[PGresult_struct],
+ arg2: int,
+ arg3: Array[PGresAttDesc_struct], # type: ignore
+) -> int: ...
+def PQtrace(
+ arg1: Optional[PGconn_struct],
+ arg2: _Pointer[FILE], # type: ignore[type-var]
+) -> None: ...
+def PQencryptPasswordConn(
+ arg1: Optional[PGconn_struct],
+ arg2: bytes,
+ arg3: bytes,
+ arg4: Optional[bytes],
+) -> bytes: ...
+def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ...
+def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ...
+
+# fmt: off
+# autogenerated: start
+def PQlibVersion() -> int: ...
+def PQconnectdb(arg1: bytes) -> PGconn_struct: ...
+def PQconnectStart(arg1: bytes) -> PGconn_struct: ...
+def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ...
+def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ...
+def PQfinish(arg1: Optional[PGconn_struct]) -> None: ...
+def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
+def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
+def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQping(arg1: bytes) -> int: ...
+def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
+def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ...
+def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
+def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ...
+def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ...
+def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
+def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
+def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
+def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
+def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
+def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
+def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ...
+def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
+def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ...
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
+def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
+def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
+def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
+def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
+def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
+def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ...
+def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ...
+def PQfreemem(arg1: Any) -> None: ...
+def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ...
+def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
+def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ...
+def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ...
+def PQinitOpenSSL(arg1: int, arg2: int) -> None: ...
+# autogenerated: end
+# fmt: on
+
+# vim: set syntax=python:
diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py
new file mode 100644
index 0000000..9c45f64
--- /dev/null
+++ b/psycopg/psycopg/pq/abc.py
@@ -0,0 +1,385 @@
+"""
+Protocol objects to represent objects exposed by different pq implementations.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from ._enums import Format, Trace
+from .._compat import Protocol
+
+if TYPE_CHECKING:
+ from .misc import PGnotify, ConninfoOption, PGresAttDesc
+
+# An object implementing the buffer protocol (ish)
+Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
+
+
+class PGconn(Protocol):
+
+ notice_handler: Optional[Callable[["PGresult"], None]]
+ notify_handler: Optional[Callable[["PGnotify"], None]]
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ ...
+
+ def connect_poll(self) -> int:
+ ...
+
+ def finish(self) -> None:
+ ...
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ ...
+
+ def reset(self) -> None:
+ ...
+
+ def reset_start(self) -> None:
+ ...
+
+ def reset_poll(self) -> int:
+ ...
+
+ @classmethod
+ def ping(self, conninfo: bytes) -> int:
+ ...
+
+ @property
+ def db(self) -> bytes:
+ ...
+
+ @property
+ def user(self) -> bytes:
+ ...
+
+ @property
+ def password(self) -> bytes:
+ ...
+
+ @property
+ def host(self) -> bytes:
+ ...
+
+ @property
+ def hostaddr(self) -> bytes:
+ ...
+
+ @property
+ def port(self) -> bytes:
+ ...
+
+ @property
+ def tty(self) -> bytes:
+ ...
+
+ @property
+ def options(self) -> bytes:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def transaction_status(self) -> int:
+ ...
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ @property
+ def server_version(self) -> int:
+ ...
+
+ @property
+ def socket(self) -> int:
+ ...
+
+ @property
+ def backend_pid(self) -> int:
+ ...
+
+ @property
+ def needs_password(self) -> bool:
+ ...
+
+ @property
+ def used_password(self) -> bool:
+ ...
+
+ @property
+ def ssl_in_use(self) -> bool:
+ ...
+
+ def exec_(self, command: bytes) -> "PGresult":
+ ...
+
+ def send_query(self, command: bytes) -> None:
+ ...
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ ...
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ ...
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional[Buffer]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ ...
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ ...
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Buffer]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ ...
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ ...
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ ...
+
+ def send_describe_portal(self, name: bytes) -> None:
+ ...
+
+ def get_result(self) -> Optional["PGresult"]:
+ ...
+
+ def consume_input(self) -> None:
+ ...
+
+ def is_busy(self) -> int:
+ ...
+
+ @property
+ def nonblocking(self) -> int:
+ ...
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ ...
+
+ def flush(self) -> int:
+ ...
+
+ def set_single_row_mode(self) -> None:
+ ...
+
+ def get_cancel(self) -> "PGcancel":
+ ...
+
+ def notifies(self) -> Optional["PGnotify"]:
+ ...
+
+ def put_copy_data(self, buffer: Buffer) -> int:
+ ...
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ ...
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ ...
+
+ def trace(self, fileno: int) -> None:
+ ...
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ ...
+
+ def untrace(self) -> None:
+ ...
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ ...
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ ...
+
+ @property
+ def pipeline_status(self) -> int:
+ ...
+
+ def enter_pipeline_mode(self) -> None:
+ ...
+
+ def exit_pipeline_mode(self) -> None:
+ ...
+
+ def pipeline_sync(self) -> None:
+ ...
+
+ def send_flush_request(self) -> None:
+ ...
+
+
+class PGresult(Protocol):
+ def clear(self) -> None:
+ ...
+
+ @property
+ def status(self) -> int:
+ ...
+
+ @property
+ def error_message(self) -> bytes:
+ ...
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def ntuples(self) -> int:
+ ...
+
+ @property
+ def nfields(self) -> int:
+ ...
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ ...
+
+ def ftable(self, column_number: int) -> int:
+ ...
+
+ def ftablecol(self, column_number: int) -> int:
+ ...
+
+ def fformat(self, column_number: int) -> int:
+ ...
+
+ def ftype(self, column_number: int) -> int:
+ ...
+
+ def fmod(self, column_number: int) -> int:
+ ...
+
+ def fsize(self, column_number: int) -> int:
+ ...
+
+ @property
+ def binary_tuples(self) -> int:
+ ...
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ ...
+
+ @property
+ def nparams(self) -> int:
+ ...
+
+ def param_type(self, param_number: int) -> int:
+ ...
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ ...
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ ...
+
+ @property
+ def oid_value(self) -> int:
+ ...
+
+ def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
+ ...
+
+
+class PGcancel(Protocol):
+ def free(self) -> None:
+ ...
+
+ def cancel(self) -> None:
+ ...
+
+
+class Conninfo(Protocol):
+ @classmethod
+ def get_defaults(cls) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
+ ...
+
+ @classmethod
+ def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
+ ...
+
+
+class Escaping(Protocol):
+ def __init__(self, conn: Optional[PGconn] = None):
+ ...
+
+ def escape_literal(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_identifier(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_string(self, data: Buffer) -> bytes:
+ ...
+
+ def escape_bytea(self, data: Buffer) -> bytes:
+ ...
+
+ def unescape_bytea(self, data: Buffer) -> bytes:
+ ...
diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py
new file mode 100644
index 0000000..3a43133
--- /dev/null
+++ b/psycopg/psycopg/pq/misc.py
@@ -0,0 +1,146 @@
+"""
+Various functionality to make it easier to work with the libpq.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import sys
+import logging
+import ctypes.util
+from typing import cast, NamedTuple, Optional, Union
+
+from .abc import PGconn, PGresult
+from ._enums import ConnStatus, TransactionStatus, PipelineStatus
+from .._compat import cache
+from .._encodings import pgconn_encoding
+
+logger = logging.getLogger("psycopg.pq")
+
+OK = ConnStatus.OK
+
+
+class PGnotify(NamedTuple):
+ relname: bytes
+ be_pid: int
+ extra: bytes
+
+
+class ConninfoOption(NamedTuple):
+ keyword: bytes
+ envvar: Optional[bytes]
+ compiled: Optional[bytes]
+ val: Optional[bytes]
+ label: bytes
+ dispchar: bytes
+ dispsize: int
+
+
+class PGresAttDesc(NamedTuple):
+ name: bytes
+ tableid: int
+ columnid: int
+ format: int
+ typid: int
+ typlen: int
+ atttypmod: int
+
+
+@cache
+def find_libpq_full_path() -> Optional[str]:
+ if sys.platform == "win32":
+ libname = ctypes.util.find_library("libpq.dll")
+
+ elif sys.platform == "darwin":
+ libname = ctypes.util.find_library("libpq.dylib")
+ # (hopefully) temporary hack: libpq not in a standard place
+ # https://github.com/orgs/Homebrew/discussions/3595
+ # If pg_config is available and agrees, let's use its indications.
+ if not libname:
+ try:
+ import subprocess as sp
+
+ libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
+ libname = os.path.join(libdir, "libpq.dylib")
+ if not os.path.exists(libname):
+ libname = None
+ except Exception as ex:
+ logger.debug("couldn't use pg_config to find libpq: %s", ex)
+
+ else:
+ libname = ctypes.util.find_library("pq")
+
+ return libname
+
+
+def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
+ """
+ Return an error message from a `PGconn` or `PGresult`.
+
+ The return value is a `!str` (unlike pq data, which is usually `!bytes`):
+ the connection encoding is used if available, otherwise the `!encoding`
+ parameter is used as a fallback for decoding. Decoding errors don't raise
+ exceptions: invalid characters are replaced.
+
+ """
+ bmsg: bytes
+
+ if hasattr(obj, "error_field"):
+ # obj is a PGresult
+ obj = cast(PGresult, obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ elif hasattr(obj, "error_message"):
+ # obj is a PGconn
+ if obj.status == OK:
+ encoding = pgconn_encoding(obj)
+ bmsg = obj.error_message
+
+ # strip severity and whitespaces
+ if bmsg:
+ bmsg = bmsg.split(b":", 1)[-1].strip()
+
+ else:
+ raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
+
+ if bmsg:
+ msg = bmsg.decode(encoding, "replace")
+ else:
+ msg = "no details available"
+
+ return msg
+
+
+def connection_summary(pgconn: PGconn) -> str:
+ """
+ Return summary information on a connection.
+
+ Useful for __repr__
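+
+ A sketch of the returned format (values are illustrative)::
+
+ "[IDLE] (host=localhost user=app database=appdb)"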
+ """
+ parts = []
+ if pgconn.status == OK:
+ # Put together the [STATUS]
+ status = TransactionStatus(pgconn.transaction_status).name
+ if pgconn.pipeline_status:
+ status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"
+
+ # Put together the (CONNECTION)
+ if not pgconn.host.startswith(b"/"):
+ parts.append(("host", pgconn.host.decode()))
+ if pgconn.port != b"5432":
+ parts.append(("port", pgconn.port.decode()))
+ if pgconn.user != pgconn.db:
+ parts.append(("user", pgconn.user.decode()))
+ parts.append(("database", pgconn.db.decode()))
+
+ else:
+ status = ConnStatus(pgconn.status).name
+
+ sparts = " ".join("%s=%s" % part for part in parts)
+ if sparts:
+ sparts = f" ({sparts})"
+ return f"[{status}]{sparts}"
diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py
new file mode 100644
index 0000000..8b87c19
--- /dev/null
+++ b/psycopg/psycopg/pq/pq_ctypes.py
@@ -0,0 +1,1086 @@
+"""
+libpq Python wrapper using ctypes bindings.
+
+Clients shouldn't use this module directly, except for testing: they should
+use the `pq` module instead, which is in charge of choosing the best
+implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+import logging
+from os import getpid
+from weakref import ref
+
+from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref
+from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import cast as t_cast, TYPE_CHECKING
+
+from .. import errors as e
+from . import _pq_ctypes as impl
+from .misc import PGnotify, ConninfoOption, PGresAttDesc
+from .misc import error_message, connection_summary
+from ._enums import Format, ExecStatus, Trace
+
+# Imported locally to call them from __del__ methods
+from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus
+
+if TYPE_CHECKING:
+ from . import abc
+
+__impl__ = "python"
+
+logger = logging.getLogger("psycopg")
+
+
+def version() -> int:
+ """Return the version number of the libpq currently loaded.
+
+ The number is in the same format as `~psycopg.ConnectionInfo.server_version`.
+
+ Certain features might not be available if the libpq library used is too old.
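+
+ For example, a value of 140005 means libpq from PostgreSQL 14.5
+ (``version() // 10000`` gives the major version).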
+ """
+ return impl.PQlibVersion()
+
+
+@impl.PQnoticeReceiver # type: ignore
+def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
+ pgconn = cast(arg, POINTER(py_object)).contents.value()
+ if not (pgconn and pgconn.notice_handler):
+ return
+
+ res = PGresult(result_ptr)
+ try:
+ pgconn.notice_handler(res)
+ except Exception as exc:
+ logger.exception("error in notice receiver: %s", exc)
+ finally:
+ res._pgresult_ptr = None # avoid destroying the pgresult_ptr
+
+
+class PGconn:
+ """
+ Python representation of a libpq connection.
+ """
+
+ __slots__ = (
+ "_pgconn_ptr",
+ "notice_handler",
+ "notify_handler",
+ "_self_ptr",
+ "_procpid",
+ "__weakref__",
+ )
+
+ def __init__(self, pgconn_ptr: impl.PGconn_struct):
+ self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
+ self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None
+ self.notify_handler: Optional[Callable[[PGnotify], None]] = None
+
+ # Keep alive for the lifetime of PGconn
+ self._self_ptr = py_object(ref(self))
+ impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr))
+
+ self._procpid = getpid()
+
+ def __del__(self) -> None:
+ # Close the connection only if it was created in this process,
+ # not if this object is being GC'd after fork.
+ if getpid() == self._procpid:
+ self.finish()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @classmethod
+ def connect(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectdb(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ @classmethod
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ pgconn_ptr = impl.PQconnectStart(conninfo)
+ if not pgconn_ptr:
+ raise MemoryError("couldn't allocate PGconn")
+ return cls(pgconn_ptr)
+
+ def connect_poll(self) -> int:
+ return self._call_int(impl.PQconnectPoll)
+
+ def finish(self) -> None:
+ self._pgconn_ptr, p = None, self._pgconn_ptr
+ if p:
+ PQfinish(p)
+
+ @property
+ def pgconn_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGconn` structure, as integer.
+
+ `!None` if the connection is closed.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgconn_ptr is None:
+ return None
+
+ return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ self._ensure_pgconn()
+ opts = impl.PQconninfo(self._pgconn_ptr)
+ if not opts:
+ raise MemoryError("couldn't allocate connection info")
+ try:
+ return Conninfo._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ def reset(self) -> None:
+ self._ensure_pgconn()
+ impl.PQreset(self._pgconn_ptr)
+
+ def reset_start(self) -> None:
+ if not impl.PQresetStart(self._pgconn_ptr):
+ raise e.OperationalError("couldn't reset connection")
+
+ def reset_poll(self) -> int:
+ return self._call_int(impl.PQresetPoll)
+
+ @classmethod
+ def ping(cls, conninfo: bytes) -> int:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ return impl.PQping(conninfo)
+
+ @property
+ def db(self) -> bytes:
+ return self._call_bytes(impl.PQdb)
+
+ @property
+ def user(self) -> bytes:
+ return self._call_bytes(impl.PQuser)
+
+ @property
+ def password(self) -> bytes:
+ return self._call_bytes(impl.PQpass)
+
+ @property
+ def host(self) -> bytes:
+ return self._call_bytes(impl.PQhost)
+
+ @property
+ def hostaddr(self) -> bytes:
+ return self._call_bytes(impl.PQhostaddr)
+
+ @property
+ def port(self) -> bytes:
+ return self._call_bytes(impl.PQport)
+
+ @property
+ def tty(self) -> bytes:
+ return self._call_bytes(impl.PQtty)
+
+ @property
+ def options(self) -> bytes:
+ return self._call_bytes(impl.PQoptions)
+
+ @property
+ def status(self) -> int:
+ return PQstatus(self._pgconn_ptr)
+
+ @property
+ def transaction_status(self) -> int:
+ return impl.PQtransactionStatus(self._pgconn_ptr)
+
+ def parameter_status(self, name: bytes) -> Optional[bytes]:
+ self._ensure_pgconn()
+ return impl.PQparameterStatus(self._pgconn_ptr, name)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQerrorMessage(self._pgconn_ptr)
+
+ @property
+ def protocol_version(self) -> int:
+ return self._call_int(impl.PQprotocolVersion)
+
+ @property
+ def server_version(self) -> int:
+ return self._call_int(impl.PQserverVersion)
+
+ @property
+ def socket(self) -> int:
+ rv = self._call_int(impl.PQsocket)
+ if rv == -1:
+ raise e.OperationalError("the connection is lost")
+ return rv
+
+ @property
+ def backend_pid(self) -> int:
+ return self._call_int(impl.PQbackendPID)
+
+ @property
+ def needs_password(self) -> bool:
+ """True if the connection authentication method required a password,
+ but none was available.
+
+ See :pq:`PQconnectionNeedsPassword` for details.
+ """
+ return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr))
+
+ @property
+ def used_password(self) -> bool:
+ """True if the connection authentication method used a password.
+
+ See :pq:`PQconnectionUsedPassword` for details.
+ """
+ return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr))
+
+ @property
+ def ssl_in_use(self) -> bool:
+ return self._call_bool(impl.PQsslInUse)
+
+ def exec_(self, command: bytes) -> "PGresult":
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQexec(self._pgconn_ptr, command)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query(self, command: bytes) -> None:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendQuery(self._pgconn_ptr, command):
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
+
+ def exec_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> "PGresult":
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ rv = impl.PQexecParams(*args)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_query_params(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ args = self._query_params_args(
+ command, param_values, param_types, param_formats, result_format
+ )
+ self._ensure_pgconn()
+ if not impl.PQsendQueryParams(*args):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes):
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_query_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> None:
+ # repurpose this function with a cheeky replacement of query with name,
+ # drop the param_types from the result
+ args = self._query_params_args(
+ name, param_values, None, param_formats, result_format
+ )
+ args = args[:3] + args[4:]
+
+ self._ensure_pgconn()
+ if not impl.PQsendQueryPrepared(*args):
+ raise e.OperationalError(
+ f"sending prepared query failed: {error_message(self)}"
+ )
+
+ def _query_params_args(
+ self,
+ command: bytes,
+ param_values: Optional[Sequence[Optional["abc.Buffer"]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = Format.TEXT,
+ ) -> Any:
+ if not isinstance(command, bytes):
+ raise TypeError(f"bytes expected, got {type(command)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alengths: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alengths = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alengths = None
+
+ atypes: Optional[Array[impl.Oid]]
+ if not param_types:
+ atypes = None
+ else:
+ if len(param_types) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_types))
+ )
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_formats"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ return (
+ self._pgconn_ptr,
+ command,
+ nparams,
+ atypes,
+ aparams,
+ alengths,
+ aformats,
+ result_format,
+ )
+
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ if not isinstance(command, bytes):
+ raise TypeError(f"'command' must be bytes, got {type(command)} instead")
+
+ if not param_types:
+ nparams = 0
+ atypes = None
+ else:
+ nparams = len(param_types)
+ atypes = (impl.Oid * nparams)(*param_types)
+
+ self._ensure_pgconn()
+ rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def exec_prepared(
+ self,
+ name: bytes,
+ param_values: Optional[Sequence["abc.Buffer"]],
+ param_formats: Optional[Sequence[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+
+ aparams: Optional[Array[c_char_p]]
+ alengths: Optional[Array[c_int]]
+ if param_values:
+ nparams = len(param_values)
+ aparams = (c_char_p * nparams)(
+ *(
+ # convert bytearray/memoryview to bytes
+ b if b is None or isinstance(b, bytes) else bytes(b)
+ for b in param_values
+ )
+ )
+ alengths = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
+ else:
+ nparams = 0
+ aparams = alengths = None
+
+ if not param_formats:
+ aformats = None
+ else:
+ if len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(param_formats))
+ )
+ aformats = (c_int * nparams)(*param_formats)
+
+ self._ensure_pgconn()
+ rv = impl.PQexecPrepared(
+ self._pgconn_ptr,
+ name,
+ nparams,
+ aparams,
+ alengths,
+ aformats,
+ result_format,
+ )
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def describe_prepared(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePrepared(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_prepared(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePrepared(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def describe_portal(self, name: bytes) -> "PGresult":
+ if not isinstance(name, bytes):
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
+ self._ensure_pgconn()
+ rv = impl.PQdescribePortal(self._pgconn_ptr, name)
+ if not rv:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult(rv)
+
+ def send_describe_portal(self, name: bytes) -> None:
+ if not isinstance(name, bytes):
+ raise TypeError(f"bytes expected, got {type(name)} instead")
+ self._ensure_pgconn()
+ if not impl.PQsendDescribePortal(self._pgconn_ptr, name):
+ raise e.OperationalError(
+ f"sending describe portal failed: {error_message(self)}"
+ )
+
+ def get_result(self) -> Optional["PGresult"]:
+ rv = impl.PQgetResult(self._pgconn_ptr)
+ return PGresult(rv) if rv else None
+
+ def consume_input(self) -> None:
+ if 1 != impl.PQconsumeInput(self._pgconn_ptr):
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
+
+ def is_busy(self) -> int:
+ return impl.PQisBusy(self._pgconn_ptr)
+
+ @property
+ def nonblocking(self) -> int:
+ return impl.PQisnonblocking(self._pgconn_ptr)
+
+ @nonblocking.setter
+ def nonblocking(self, arg: int) -> None:
+ if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg):
+ raise e.OperationalError(
+ f"setting nonblocking failed: {error_message(self)}"
+ )
+
+ def flush(self) -> int:
+ # PQflush segfaults if it receives a NULL connection
+ if not self._pgconn_ptr:
+ raise e.OperationalError("flushing failed: the connection is closed")
+ rv: int = impl.PQflush(self._pgconn_ptr)
+ if rv < 0:
+ raise e.OperationalError(f"flushing failed: {error_message(self)}")
+ return rv
+
+ def set_single_row_mode(self) -> None:
+ if not impl.PQsetSingleRowMode(self._pgconn_ptr):
+ raise e.OperationalError("setting single row mode failed")
+
+ def get_cancel(self) -> "PGcancel":
+ """
+ Create an object with the information needed to cancel a command.
+
+ See :pq:`PQgetCancel` for details.
+ """
+ rv = impl.PQgetCancel(self._pgconn_ptr)
+ if not rv:
+ raise e.OperationalError("couldn't create cancel object")
+ return PGcancel(rv)
+
+ def notifies(self) -> Optional[PGnotify]:
+ ptr = impl.PQnotifies(self._pgconn_ptr)
+ if ptr:
+ c = ptr.contents
+ # copy the fields, then free the structure allocated by the libpq
+ rv = PGnotify(c.relname, c.be_pid, c.extra)
+ impl.PQfreemem(ptr)
+ return rv
+ else:
+ return None
+
+ def put_copy_data(self, buffer: "abc.Buffer") -> int:
+ if not isinstance(buffer, bytes):
+ buffer = bytes(buffer)
+ rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
+ if rv < 0:
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
+ return rv
+
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
+ buffer_ptr = c_char_p()
+ nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_)
+ if nbytes == -2:
+ raise e.OperationalError(
+ f"receiving copy data failed: {error_message(self)}"
+ )
+ if buffer_ptr:
+ # TODO: do it without copy
+ data = string_at(buffer_ptr, nbytes)
+ impl.PQfreemem(buffer_ptr)
+ return nbytes, memoryview(data)
+ else:
+ return nbytes, memoryview(b"")
+
+ def trace(self, fileno: int) -> None:
+ """
+ Enable tracing of the client/server communication to a file stream.
+
+ See :pq:`PQtrace` for details.
+ """
+ if sys.platform != "linux":
+ raise e.NotSupportedError("currently only supported on Linux")
+ stream = impl.fdopen(fileno, b"w")
+ impl.PQtrace(self._pgconn_ptr, stream)
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ """
+ Configure tracing behavior of client/server communication.
+
+ :param flags: operating mode of tracing.
+
+ See :pq:`PQsetTraceFlags` for details.
+ """
+ impl.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+ def untrace(self) -> None:
+ """
+ Disable tracing, previously enabled through `trace()`.
+
+ See :pq:`PQuntrace` for details.
+ """
+ impl.PQuntrace(self._pgconn_ptr)
+
+ def encrypt_password(
+ self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
+ ) -> bytes:
+ """
+ Return the encrypted form of a PostgreSQL password.
+
+ See :pq:`PQencryptPasswordConn` for details.
+ """
+ out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm)
+ if not out:
+ raise e.OperationalError(
+ f"password encryption failed: {error_message(self)}"
+ )
+
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def make_empty_result(self, exec_status: int) -> "PGresult":
+ rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)
+ if not rv:
+ raise MemoryError("couldn't allocate empty PGresult")
+ return PGresult(rv)
+
+ @property
+ def pipeline_status(self) -> int:
+ if version() < 140000:
+ return 0
+ return impl.PQpipelineStatus(self._pgconn_ptr)
+
+ def enter_pipeline_mode(self) -> None:
+ """Enter pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to enter the pipeline
+ mode.
+ """
+ if impl.PQenterPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError("failed to enter pipeline mode")
+
+ def exit_pipeline_mode(self) -> None:
+ """Exit pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to exit the pipeline
+ mode.
+ """
+ if impl.PQexitPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError(error_message(self))
+
+ def pipeline_sync(self) -> None:
+ """Mark a synchronization point in a pipeline.
+
+ :raises ~e.OperationalError: if the connection is not in pipeline mode
+ or if sync failed.
+ """
+ rv = impl.PQpipelineSync(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError("connection not in pipeline mode")
+ if rv != 1:
+ raise e.OperationalError("failed to sync pipeline")
+
+ def send_flush_request(self) -> None:
+ """Sends a request for the server to flush its output buffer.
+
+ :raises ~e.OperationalError: if the flush request failed.
+ """
+ if impl.PQsendFlushRequest(self._pgconn_ptr) == 0:
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
+
+ def _call_bytes(
+ self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
+ ) -> bytes:
+ """
+ Call one of the pgconn libpq functions returning a bytes pointer.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ rv = func(self._pgconn_ptr)
+ assert rv is not None
+ return rv
+
+ def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int:
+ """
+ Call one of the pgconn libpq functions returning an int.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return func(self._pgconn_ptr)
+
+ def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool:
+ """
+ Call one of the pgconn libpq functions returning a logical value.
+ """
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+ return bool(func(self._pgconn_ptr))
+
+ def _ensure_pgconn(self) -> None:
+ if not self._pgconn_ptr:
+ raise e.OperationalError("the connection is closed")
+
+
+class PGresult:
+ """
+ Python representation of a libpq result.
+ """
+
+ __slots__ = ("_pgresult_ptr",)
+
+ def __init__(self, pgresult_ptr: impl.PGresult_struct):
+ self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr
+
+ def __del__(self) -> None:
+ self.clear()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ status = ExecStatus(self.status)
+ return f"<{cls} [{status.name}] at 0x{id(self):x}>"
+
+ def clear(self) -> None:
+ self._pgresult_ptr, p = None, self._pgresult_ptr
+ if p:
+ PQclear(p)
+
+ @property
+ def pgresult_ptr(self) -> Optional[int]:
+ """The pointer to the underlying `!PGresult` structure, as integer.
+
+ `!None` if the result was cleared.
+
+ The value can be used to pass the structure to libpq functions which
+ psycopg doesn't (currently) wrap, either in C or in Python using FFI
+ libraries such as `ctypes`.
+ """
+ if self._pgresult_ptr is None:
+ return None
+
+ return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined]
+
+ @property
+ def status(self) -> int:
+ return impl.PQresultStatus(self._pgresult_ptr)
+
+ @property
+ def error_message(self) -> bytes:
+ return impl.PQresultErrorMessage(self._pgresult_ptr)
+
+ def error_field(self, fieldcode: int) -> Optional[bytes]:
+ return impl.PQresultErrorField(self._pgresult_ptr, fieldcode)
+
+ @property
+ def ntuples(self) -> int:
+ return impl.PQntuples(self._pgresult_ptr)
+
+ @property
+ def nfields(self) -> int:
+ return impl.PQnfields(self._pgresult_ptr)
+
+ def fname(self, column_number: int) -> Optional[bytes]:
+ return impl.PQfname(self._pgresult_ptr, column_number)
+
+ def ftable(self, column_number: int) -> int:
+ return impl.PQftable(self._pgresult_ptr, column_number)
+
+ def ftablecol(self, column_number: int) -> int:
+ return impl.PQftablecol(self._pgresult_ptr, column_number)
+
+ def fformat(self, column_number: int) -> int:
+ return impl.PQfformat(self._pgresult_ptr, column_number)
+
+ def ftype(self, column_number: int) -> int:
+ return impl.PQftype(self._pgresult_ptr, column_number)
+
+ def fmod(self, column_number: int) -> int:
+ return impl.PQfmod(self._pgresult_ptr, column_number)
+
+ def fsize(self, column_number: int) -> int:
+ return impl.PQfsize(self._pgresult_ptr, column_number)
+
+ @property
+ def binary_tuples(self) -> int:
+ return impl.PQbinaryTuples(self._pgresult_ptr)
+
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number)
+ if length:
+ v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
+ return string_at(v, length)
+ else:
+ if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
+ return None
+ else:
+ return b""
+
+ @property
+ def nparams(self) -> int:
+ return impl.PQnparams(self._pgresult_ptr)
+
+ def param_type(self, param_number: int) -> int:
+ return impl.PQparamtype(self._pgresult_ptr, param_number)
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ return impl.PQcmdStatus(self._pgresult_ptr)
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ rv = impl.PQcmdTuples(self._pgresult_ptr)
+ return int(rv) if rv else None
+
+ @property
+ def oid_value(self) -> int:
+ return impl.PQoidValue(self._pgresult_ptr)
+
+ def set_attributes(self, descriptions: List[PGresAttDesc]) -> None:
+ structs = [
+ impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore
+ ]
+ array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore
+ rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
+ if rv == 0:
+ raise e.OperationalError("PQsetResultAttrs failed")
+
+
+class PGcancel:
+ """
+ Token to cancel the current operation on a connection.
+
+ Created by `PGconn.get_cancel()`.
+ """
+
+ __slots__ = ("pgcancel_ptr",)
+
+ def __init__(self, pgcancel_ptr: impl.PGcancel_struct):
+ self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr
+
+ def __del__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ """
+ Free the data structure created by :pq:`PQgetCancel()`.
+
+ Automatically invoked by `!__del__()`.
+
+ See :pq:`PQfreeCancel()` for details.
+ """
+ self.pgcancel_ptr, p = None, self.pgcancel_ptr
+ if p:
+ PQfreeCancel(p)
+
+ def cancel(self) -> None:
+ """Requests that the server abandon processing of the current command.
+
+ See :pq:`PQcancel()` for details.
+ """
+ buf = create_string_buffer(256)
+ res = impl.PQcancel(
+ self.pgcancel_ptr,
+ byref(buf), # type: ignore[arg-type]
+ len(buf),
+ )
+ if not res:
+ raise e.OperationalError(
+ f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
+ )
+
+
+class Conninfo:
+ """
+ Utility object to manipulate connection strings.
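+
+ Example (a minimal sketch)::
+
+ >>> opts = Conninfo.parse(b"dbname=test user=postgres")
+ >>> # each item is a ConninfoOption; keywords present in the string have `val` set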
+ """
+
+ @classmethod
+ def get_defaults(cls) -> List[ConninfoOption]:
+ opts = impl.PQconndefaults()
+ if not opts:
+ raise MemoryError("couldn't allocate connection defaults")
+ try:
+ return cls._options_from_array(opts)
+ finally:
+ impl.PQconninfoFree(opts)
+
+ @classmethod
+ def parse(cls, conninfo: bytes) -> List[ConninfoOption]:
+ if not isinstance(conninfo, bytes):
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
+
+ errmsg = c_char_p()
+ rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type]
+ if not rv:
+ if not errmsg:
+ raise MemoryError("couldn't allocate on conninfo parse")
+ else:
+ exc = e.OperationalError(
+ (errmsg.value or b"").decode("utf8", "replace")
+ )
+ impl.PQfreemem(errmsg)
+ raise exc
+
+ try:
+ return cls._options_from_array(rv)
+ finally:
+ impl.PQconninfoFree(rv)
+
+ @classmethod
+ def _options_from_array(
+ cls, opts: Sequence[impl.PQconninfoOption_struct]
+ ) -> List[ConninfoOption]:
+ rv = []
+ skws = "keyword envvar compiled val label dispchar".split()
+ for opt in opts:
+ if not opt.keyword:
+ break
+ d = {kw: getattr(opt, kw) for kw in skws}
+ d["dispsize"] = opt.dispsize
+ rv.append(ConninfoOption(**d))
+
+ return rv
+
+
+class Escaping:
+ """
+ Utility object to escape strings for SQL interpolation.
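+
+ Example (a minimal sketch, assuming an open `PGconn` in ``pgconn``; the
+ output shown is only illustrative)::
+
+ >>> esc = Escaping(pgconn)
+ >>> esc.escape_literal(b"fo'o")
+ b"'fo''o'"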
+ """
+
+ def __init__(self, conn: Optional[PGconn] = None):
+ self.conn = conn
+
+ def escape_literal(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_literal failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+ # TODO: might be done without copy (however C does that)
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_literal failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_identifier(self, data: "abc.Buffer") -> bytes:
+ if not self.conn:
+ raise e.OperationalError("escape_identifier failed: no connection provided")
+
+ self.conn._ensure_pgconn()
+
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))
+ if not out:
+ raise e.OperationalError(
+ f"escape_identifier failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ def escape_string(self, data: "abc.Buffer") -> bytes:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ if self.conn:
+ self.conn._ensure_pgconn()
+ error = c_int()
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeStringConn(
+ self.conn._pgconn_ptr,
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ byref(error), # type: ignore[arg-type]
+ )
+
+ if error:
+ raise e.OperationalError(
+ f"escape_string failed: {error_message(self.conn)} bytes"
+ )
+
+ else:
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeString(
+ byref(out), # type: ignore[arg-type]
+ data,
+ len(data),
+ )
+
+ return out.value
+
+ def escape_bytea(self, data: "abc.Buffer") -> bytes:
+ len_out = c_size_t()
+ # TODO: might be able to do without a copy but it's a mess.
+ # the C library does it better anyway, so maybe not worth optimising
+ # https://mail.python.org/pipermail/python-dev/2012-September/121780.html
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ if self.conn:
+ self.conn._ensure_pgconn()
+ out = impl.PQescapeByteaConn(
+ self.conn._pgconn_ptr,
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ else:
+ out = impl.PQescapeBytea(
+ data,
+ len(data),
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value - 1) # out includes final 0
+ impl.PQfreemem(out)
+ return rv
+
+ def unescape_bytea(self, data: "abc.Buffer") -> bytes:
+ # not needed, but let's keep it symmetric with the escaping:
+ # if a connection is passed in, it must be valid.
+ if self.conn:
+ self.conn._ensure_pgconn()
+
+ len_out = c_size_t()
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ out = impl.PQunescapeBytea(
+ data,
+ byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type]
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value)
+ impl.PQfreemem(out)
+ return rv
+
+
+# importing the ssl module sets up Python's libcrypto callbacks
+import ssl # noqa
+
+# disable libcrypto setup in libpq, so it won't stomp on the callbacks
+# that have already been set up
+impl.PQinitOpenSSL(1, 0)
+
+__build_version__ = version()
diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg/psycopg/py.typed
diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py
new file mode 100644
index 0000000..cb28b57
--- /dev/null
+++ b/psycopg/psycopg/rows.py
@@ -0,0 +1,256 @@
+"""
+psycopg row factories
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import functools
+from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
+from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
+from collections import namedtuple
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from ._compat import Protocol
+from ._encodings import _as_python_identifier
+
+if TYPE_CHECKING:
+ from .cursor import BaseCursor, Cursor
+ from .cursor_async import AsyncCursor
+ from psycopg.pq.abc import PGresult
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+
+T = TypeVar("T", covariant=True)
+
+# Row factories
+
+Row = TypeVar("Row", covariant=True)
+
+
+class RowMaker(Protocol[Row]):
+ """
+ Callable protocol taking a sequence of values and returning an object.
+
+ The sequence of values is what is returned from a database query, already
+ adapted to the right Python types. The return value is the object that your
+ program would like to receive: by default (`tuple_row()`) it is a simple
+ tuple, but it may be any type of object.
+
+ Typically, `!RowMaker` functions are returned by `RowFactory`.
+ """
+
+ def __call__(self, __values: Sequence[Any]) -> Row:
+ ...
+
+
+class RowFactory(Protocol[Row]):
+ """
+ Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`.
+
+ A `!RowFactory` is typically called when a `!Cursor` receives a result.
+ This way it can inspect the cursor state (for instance the
+ `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create
+ a complete object.
+
+ For instance the `dict_row()` `!RowFactory` uses the names of the columns to
+ define the dictionary keys and returns a `!RowMaker` function which would
+ use the values to create a dictionary for each record.
+ """
+
+ def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class AsyncRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking an async cursor as argument.
+ """
+
+ def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
+ ...
+
+
+class BaseRowFactory(Protocol[Row]):
+ """
+ Like `RowFactory`, taking either type of cursor as argument.
+ """
+
+ def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
+ ...
+
+
+TupleRow: TypeAlias = Tuple[Any, ...]
+"""
+An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
+"""
+
+
+DictRow: TypeAlias = Dict[str, Any]
+"""
+An alias for the type returned by `dict_row()`
+
+A `!DictRow` is a dictionary with string keys and values of any type returned
+by the database.
+"""
+
+
+def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
+ r"""Row factory to represent rows as simple tuples.
+
+ This is the default factory, used when `~psycopg.Connection.connect()` or
+ `~psycopg.Connection.cursor()` are called without a `!row_factory`
+ parameter.
+
+ """
+ # Implementation detail: make sure this is the tuple type itself, not an
+ # equivalent function, because the C code fast-paths on it.
+ return tuple
+
+
+def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
+ """Row factory to represent rows as dictionaries.
+
+ The dictionary keys are taken from the column names of the returned columns.
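+
+ Example (a minimal sketch, assuming an open connection ``conn``)::
+
+ >>> cur = conn.cursor(row_factory=dict_row)
+ >>> cur.execute("SELECT 1 AS id, 'foo' AS name").fetchone()
+ {'id': 1, 'name': 'foo'}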
+ """
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
+ # https://github.com/python/mypy/issues/2608
+ return dict(zip(names, values)) # type: ignore[arg-type]
+
+ return dict_row_
+
+
+def namedtuple_row(
+ cursor: "BaseCursor[Any, Any]",
+) -> "RowMaker[NamedTuple]":
+ """Row factory to represent rows as `~collections.namedtuple`.
+
+ The field names are taken from the column names of the returned columns,
+ with some mangling to deal with invalid names.
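+
+ Example (a minimal sketch, assuming an open connection ``conn``)::
+
+ >>> cur = conn.cursor(row_factory=namedtuple_row)
+ >>> cur.execute("SELECT 1 AS id, 'foo' AS name").fetchone()
+ Row(id=1, name='foo')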
+ """
+ res = cursor.pgresult
+ if not res:
+ return no_result
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return no_result
+
+ nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
+ return nt._make
+
+
+@functools.lru_cache(512)
+def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
+ snames = tuple(_as_python_identifier(n.decode(enc)) for n in names)
+ return namedtuple("Row", snames) # type: ignore[return-value]
+
+
+def class_row(cls: Type[T]) -> BaseRowFactory[T]:
+ r"""Generate a row factory to represent rows as instances of the class `!cls`.
+
+ The class must support every output column name as a keyword parameter.
+
+ :param cls: The class to return for each row. It must support the fields
+ returned by the query as keyword arguments.
+ :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
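+
+ Example (a minimal sketch, assuming a hypothetical dataclass ``Person`` with
+ ``name`` and ``age`` fields and an open connection ``conn``)::
+
+ >>> cur = conn.cursor(row_factory=class_row(Person))
+ >>> cur.execute("SELECT 'Ann' AS name, 42 AS age").fetchone()
+ Person(name='Ann', age=42)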
+ """
+
+ def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def class_row__(values: Sequence[Any]) -> T:
+ return cls(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return class_row__
+
+ return class_row_
+
+
+def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with positional parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as positional arguments.
+ """
+
+ def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ def args_row__(values: Sequence[Any]) -> T:
+ return func(*values)
+
+ return args_row__
+
+ return args_row_
+
+
+def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+ """Generate a row factory calling `!func` with keyword parameters for every row.
+
+ :param func: The function to call for each row. It must support the fields
+ returned by the query as keyword arguments.
+ """
+
+ def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
+ names = _get_names(cursor)
+ if names is None:
+ return no_result
+
+ def kwargs_row__(values: Sequence[Any]) -> T:
+ return func(**dict(zip(names, values))) # type: ignore[arg-type]
+
+ return kwargs_row__
+
+ return kwargs_row_
+
+
+def no_result(values: Sequence[Any]) -> NoReturn:
+ """A `RowMaker` that always fail.
+
+ It can be used as return value for a `RowFactory` called with no result.
+ Note that the `!RowFactory` *will* be called with no result, but the
+ resulting `!RowMaker` never should.
+ """
+ raise e.InterfaceError("the cursor doesn't have a result")
+
+
+def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
+ res = cursor.pgresult
+ if not res:
+ return None
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return None
+
+ enc = cursor._encoding
+ return [
+ res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
+ ]
+
+
+def _get_nfields(res: "PGresult") -> Optional[int]:
+ """
+ Return the number of columns in a result if it returns tuples, else `!None`.
+
+ Take into account the special case of results with zero columns.
+ """
+ nfields = res.nfields
+
+ if (
+ res.status == TUPLES_OK
+ or res.status == SINGLE_TUPLE
+ # "describe" in named cursors
+ or (res.status == COMMAND_OK and nfields)
+ ):
+ return nfields
+ else:
+ return None
diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py
new file mode 100644
index 0000000..b890d77
--- /dev/null
+++ b/psycopg/psycopg/server_cursor.py
@@ -0,0 +1,479 @@
+"""
+psycopg server-side cursor objects.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, AsyncIterator, List, Iterable, Iterator
+from typing import Optional, TypeVar, TYPE_CHECKING, overload
+from warnings import warn
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .rows import Row, RowFactory, AsyncRowFactory
+from .cursor import BaseCursor, Cursor
+from .generators import execute
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+DEFAULT_ITERSIZE = 100
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+
+class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
+ """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
+
+ __slots__ = "_name _scrollable _withhold _described itersize _format".split()
+
+ def __init__(
+ self,
+ name: str,
+ scrollable: Optional[bool],
+ withhold: bool,
+ ):
+ self._name = name
+ self._scrollable = scrollable
+ self._withhold = withhold
+ self._described = False
+ self.itersize: int = DEFAULT_ITERSIZE
+ self._format = TEXT
+
+ def __repr__(self) -> str:
+ # Insert the name as the second word
+ parts = super().__repr__().split(None, 1)
+ parts.insert(1, f"{self._name!r}")
+ return " ".join(parts)
+
+ @property
+ def name(self) -> str:
+ """The name of the cursor."""
+ return self._name
+
+ @property
+ def scrollable(self) -> Optional[bool]:
+ """
+ Whether the cursor is scrollable or not.
+
+ If `!None` leave the choice to the server. Use `!True` if you want to
+ use `scroll()` on the cursor.
+ """
+ return self._scrollable
+
+ @property
+ def withhold(self) -> bool:
+ """
+ Whether the cursor can be used after the creating transaction has committed.
+ """
+ return self._withhold
+
+ @property
+ def rownumber(self) -> Optional[int]:
+ """Index of the next row to fetch in the current result.
+
+ `!None` if there is no result to fetch.
+ """
+ res = self.pgresult
+ # command_status is empty if the result comes from
+ # describe_portal, which means that we have just executed the DECLARE,
+ # so we can assume we are at the first row.
+ tuples = res and (res.status == TUPLES_OK or res.command_status == b"")
+ return self._pos if tuples else None
+
+ def _declare_gen(
+ self,
+ query: Query,
+ params: Optional[Params] = None,
+ binary: Optional[bool] = None,
+ ) -> PQGen[None]:
+ """Generator implementing `ServerCursor.execute()`."""
+
+ query = self._make_declare_statement(query)
+
+ # If the cursor is being reused, the previous one must be closed.
+ if self._described:
+ yield from self._close_gen()
+ self._described = False
+
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, force_extended=True)
+ results = yield from execute(self._conn.pgconn)
+ if results[-1].status != COMMAND_OK:
+ self._raise_for_result(results[-1])
+
+ # Set the format, which will be used by describe and fetch operations
+ if binary is None:
+ self._format = self.format
+ else:
+ self._format = BINARY if binary else TEXT
+
+ # The above result only returned COMMAND_OK. Get the cursor shape
+ yield from self._describe_gen()
+
+ def _describe_gen(self) -> PQGen[None]:
+ self._pgconn.send_describe_portal(self._name.encode(self._encoding))
+ results = yield from execute(self._pgconn)
+ self._check_results(results)
+ self._results = results
+ self._select_current_result(0, format=self._format)
+ self._described = True
+
+ def _close_gen(self) -> PQGen[None]:
+ ts = self._conn.pgconn.transaction_status
+
+ # if the connection is not in a sane state, don't even try
+ if ts != IDLE and ts != INTRANS:
+ return
+
+ # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
+ if not self._withhold and ts == IDLE:
+ return
+
+ # if we didn't declare the cursor ourselves we still have to close it
+ # but we must make sure it exists.
+ if not self._described:
+ query = sql.SQL(
+ "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
+ ).format(sql.Literal(self._name))
+ res = yield from self._conn._exec_command(query)
+ # the result is None only in pipeline mode, which is not supported here
+ assert res is not None
+ if res.ntuples == 0:
+ return
+
+ query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
+ yield from self._conn._exec_command(query)
+
+ def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
+ if self.closed:
+ raise e.InterfaceError("the cursor is closed")
+ # If we are stealing the cursor, make sure we know its shape
+ if not self._described:
+ yield from self._start_query()
+ yield from self._describe_gen()
+
+ query = sql.SQL("FETCH FORWARD {} FROM {}").format(
+ sql.SQL("ALL") if num is None else sql.Literal(num),
+ sql.Identifier(self._name),
+ )
+ res = yield from self._conn._exec_command(query, result_format=self._format)
+ # the result is None only in pipeline mode, which is not supported here
+ assert res is not None
+
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=False)
+ return self._tx.load_rows(0, res.ntuples, self._make_row)
+
+ def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+ if mode not in ("relative", "absolute"):
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+ query = sql.SQL("MOVE{} {} FROM {}").format(
+ sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
+ sql.Literal(value),
+ sql.Identifier(self._name),
+ )
+ yield from self._conn._exec_command(query)
+
+ def _make_declare_statement(self, query: Query) -> sql.Composed:
+
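+ # The generated statement has the following shape (optional parts depend
+ # on the cursor options):
+ # DECLARE "name" [SCROLL | NO SCROLL] CURSOR [WITH HOLD] FOR <query>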
+ if isinstance(query, bytes):
+ query = query.decode(self._encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
+
+ parts = [
+ sql.SQL("DECLARE"),
+ sql.Identifier(self._name),
+ ]
+ if self._scrollable is not None:
+ parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
+ parts.append(sql.SQL("CURSOR"))
+ if self._withhold:
+ parts.append(sql.SQL("WITH HOLD"))
+ parts.append(sql.SQL("FOR"))
+ parts.append(query)
+
+ return sql.SQL(" ").join(parts)
+
+
+class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="ServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "ServerCursor[Row]",
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: RowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "Connection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[RowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ Cursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ def close(self) -> None:
+ """
+ Close the current cursor and free associated resources.
+ """
+ with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ self._conn.wait(self._close_gen())
+ super().close()
+
+ def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ """
+ Open a cursor to execute a query to the database.
+ """
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ with self._conn.lock:
+ self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ """Method not implemented for server-side cursors."""
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ def fetchone(self) -> Optional[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ def fetchall(self) -> List[Row]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ def __iter__(self) -> Iterator[Row]:
+ while True:
+ with self._conn.lock:
+ recs = self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ def scroll(self, value: int, mode: str = "relative") -> None:
+ with self._conn.lock:
+ self._conn.wait(self._scroll_gen(value, mode))
+ # Postgres doesn't have a reliable way to report a cursor out of bounds
+ if mode == "relative":
+ self._pos += value
+ else:
+ self._pos = value
+
+
+class AsyncServerCursor(
+ ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
+):
+ __module__ = "psycopg"
+ __slots__ = ()
+ _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Row]",
+ name: str,
+ *,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncServerCursor[Row]",
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: AsyncRowFactory[Row],
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ ...
+
+ def __init__(
+ self,
+ connection: "AsyncConnection[Any]",
+ name: str,
+ *,
+ row_factory: Optional[AsyncRowFactory[Row]] = None,
+ scrollable: Optional[bool] = None,
+ withhold: bool = False,
+ ):
+ AsyncCursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
+
+ def __del__(self) -> None:
+ if not self.closed:
+ warn(
+ f"the server-side cursor {self} was deleted while still open."
+ " Please use 'with' or '.close()' to close the cursor properly",
+ ResourceWarning,
+ )
+
+ async def close(self) -> None:
+ async with self._conn.lock:
+ if self.closed:
+ return
+ if not self._conn.closed:
+ await self._conn.wait(self._close_gen())
+ await super().close()
+
+ async def execute(
+ self: _Self,
+ query: Query,
+ params: Optional[Params] = None,
+ *,
+ binary: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> _Self:
+ if kwargs:
+ raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
+ if self._pgconn.pipeline_status:
+ raise e.NotSupportedError(
+ "server-side cursors not supported in pipeline mode"
+ )
+
+ try:
+ async with self._conn.lock:
+ await self._conn.wait(self._declare_gen(query, params, binary))
+ except e.Error as ex:
+ raise ex.with_traceback(None)
+
+ return self
+
+ async def executemany(
+ self,
+ query: Query,
+ params_seq: Iterable[Params],
+ *,
+ returning: bool = True,
+ ) -> None:
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
+
+ async def fetchone(self) -> Optional[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(1))
+ if recs:
+ self._pos += 1
+ return recs[0]
+ else:
+ return None
+
+ async def fetchmany(self, size: int = 0) -> List[Row]:
+ if not size:
+ size = self.arraysize
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(size))
+ self._pos += len(recs)
+ return recs
+
+ async def fetchall(self) -> List[Row]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(None))
+ self._pos += len(recs)
+ return recs
+
+ async def __aiter__(self) -> AsyncIterator[Row]:
+ while True:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._fetch_gen(self.itersize))
+ for rec in recs:
+ self._pos += 1
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
+ async def scroll(self, value: int, mode: str = "relative") -> None:
+ async with self._conn.lock:
+ await self._conn.wait(self._scroll_gen(value, mode))
diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py
new file mode 100644
index 0000000..099a01c
--- /dev/null
+++ b/psycopg/psycopg/sql.py
@@ -0,0 +1,467 @@
+"""
+SQL composition utility module
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+import string
+from abc import ABC, abstractmethod
+from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
+
+from .pq import Escaping
+from .abc import AdaptContext
+from .adapt import Transformer, PyFormat
+from ._compat import LiteralString
+from ._encodings import conn_encoding
+
+
+def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
+ """
+ Adapt a Python object to a quoted SQL string.
+
+ Use this function only if you absolutely need to convert a Python object to
+ an SQL quoted literal (e.g. to generate batch SQL) and you won't have a
+ connection available when you need to use it.
+
+ This function is relatively inefficient, because it doesn't cache the
+ adaptation rules. If you pass a `!context` you can customise the adaptation
+ rules used; otherwise only the global rules are used.
+
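+    Example (an illustrative sketch; the exact output may vary with the
+    adaptation rules in use)::
+
+        >>> from psycopg import sql
+        >>> sql.quote("O'Rourke")
+        "'O''Rourke'"
+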
+ """
+ return Literal(obj).as_string(context)
+
+
+class Composable(ABC):
+ """
+ Abstract base class for objects that can be used to compose an SQL string.
+
+ `!Composable` objects can be passed directly to
+ `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
+ `~psycopg.Cursor.copy()` in place of the query string.
+
+ `!Composable` objects can be joined using the ``+`` operator: the result
+ will be a `Composed` instance containing the objects joined. The operator
+ ``*`` is also supported with an integer argument: the result is a
+ `!Composed` instance containing the left argument repeated as many times as
+ requested.
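+
+    Example (illustrative, assuming an open connection `!conn`)::
+
+        >>> comp = sql.SQL("SELECT * FROM ") + sql.Identifier("table")
+        >>> print(comp.as_string(conn))
+        SELECT * FROM "table"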
+ """
+
+ def __init__(self, obj: Any):
+ self._obj = obj
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._obj!r})"
+
+ @abstractmethod
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ """
+ Return the value of the object as bytes.
+
+ :param context: the context to evaluate the object into.
+ :type context: `connection` or `cursor`
+
+ The method is automatically invoked by `~psycopg.Cursor.execute()`,
+ `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
+ `!Composable` is passed instead of the query string.
+
+ """
+ raise NotImplementedError
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ """
+ Return the value of the object as string.
+
+ :param context: the context to evaluate the string into.
+ :type context: `connection` or `cursor`
+
+ """
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ b = self.as_bytes(context)
+ if isinstance(b, bytes):
+ return b.decode(enc)
+ else:
+ # buffer object
+ return codecs.lookup(enc).decode(b)[0]
+
+ def __add__(self, other: "Composable") -> "Composed":
+ if isinstance(other, Composed):
+ return Composed([self]) + other
+ if isinstance(other, Composable):
+ return Composed([self]) + Composed([other])
+ else:
+ return NotImplemented
+
+ def __mul__(self, n: int) -> "Composed":
+ return Composed([self] * n)
+
+ def __eq__(self, other: Any) -> bool:
+ return type(self) is type(other) and self._obj == other._obj
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Composed(Composable):
+ """
+ A `Composable` object made of a sequence of `!Composable`.
+
+ The object is usually created using `!Composable` operators and methods.
+ However it is possible to create a `!Composed` directly specifying a
+ sequence of objects as arguments: if they are not `!Composable` they will
+ be wrapped in a `Literal`.
+
+ Example::
+
+ >>> comp = sql.Composed(
+ ... [sql.SQL("INSERT INTO "), sql.Identifier("table")])
+ >>> print(comp.as_string(conn))
+ INSERT INTO "table"
+
+ `!Composed` objects are iterable (so they can be used in `SQL.join` for
+ instance).
+ """
+
+ _obj: List[Composable]
+
+ def __init__(self, seq: Sequence[Any]):
+ seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
+ super().__init__(seq)
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ return b"".join(obj.as_bytes(context) for obj in self._obj)
+
+ def __iter__(self) -> Iterator[Composable]:
+ return iter(self._obj)
+
+ def __add__(self, other: Composable) -> "Composed":
+ if isinstance(other, Composed):
+ return Composed(self._obj + other._obj)
+ if isinstance(other, Composable):
+ return Composed(self._obj + [other])
+ else:
+ return NotImplemented
+
+ def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
+ """
+ Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
+
+ The `!joiner` must be a `SQL` or a string which will be interpreted as
+ an `SQL`.
+
+ Example::
+
+ >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
+ >>> print(fields.join(', ').as_string(conn))
+ "foo", "bar"
+
+ """
+ if isinstance(joiner, str):
+ joiner = SQL(joiner)
+ elif not isinstance(joiner, SQL):
+ raise TypeError(
+ "Composed.join() argument must be strings or SQL,"
+ f" got {joiner!r} instead"
+ )
+
+ return joiner.join(self._obj)
+
+
+class SQL(Composable):
+ """
+ A `Composable` representing a snippet of SQL statement.
+
+ `!SQL` exposes `join()` and `format()` methods useful to create a template
+ where to merge variable parts of a query (for instance field or table
+ names).
+
+ The `!obj` string doesn't undergo any form of escaping, so it is not
+ suitable to represent variable identifiers or values: you should only use
+ it to pass constant strings representing templates or snippets of SQL
+ statements; use other objects such as `Identifier` or `Literal` to
+ represent variable parts.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {0} FROM {1}").format(
+ ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
+ ... sql.Identifier('table'))
+ >>> print(query.as_string(conn))
+ SELECT "foo", "bar" FROM "table"
+ """
+
+ _obj: LiteralString
+ _formatter = string.Formatter()
+
+ def __init__(self, obj: LiteralString):
+ super().__init__(obj)
+ if not isinstance(obj, str):
+ raise TypeError(f"SQL values must be strings, got {obj!r} instead")
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ return self._obj
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ enc = "utf-8"
+ if context:
+ enc = conn_encoding(context.connection)
+ return self._obj.encode(enc)
+
+ def format(self, *args: Any, **kwargs: Any) -> Composed:
+ """
+ Merge `Composable` objects into a template.
+
+        :param args: parameters to replace the numbered (``{0}``, ``{1}``) or
+            auto-numbered (``{}``) placeholders
+        :param kwargs: parameters to replace the named (``{name}``) placeholders
+ :return: the union of the `!SQL` string with placeholders replaced
+ :rtype: `Composed`
+
+ The method is similar to the Python `str.format()` method: the string
+ template supports auto-numbered (``{}``), numbered (``{0}``,
+ ``{1}``...), and named placeholders (``{name}``), with positional
+ arguments replacing the numbered placeholders and keywords replacing
+ the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
+ are not supported.
+
+        If a `!Composable` object is passed to the template it will be merged
+ according to its `as_string()` method. If any other Python object is
+ passed, it will be wrapped in a `Literal` object and so escaped
+ according to SQL rules.
+
+ Example::
+
+ >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
+ ... .format(sql.Identifier('people'), sql.Identifier('id'))
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE "id" = %s
+
+ >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
+            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
+ ... .as_string(conn))
+ SELECT * FROM "people" WHERE name = 'O''Rourke'
+
+ """
+ rv: List[Composable] = []
+ autonum: Optional[int] = 0
+ # TODO: this is probably not the right way to whitelist pre
+ # pyre complains. Will wait for mypy to complain too to fix.
+ pre: LiteralString
+ for pre, name, spec, conv in self._formatter.parse(self._obj):
+ if spec:
+ raise ValueError("no format specification supported by SQL")
+ if conv:
+ raise ValueError("no format conversion supported by SQL")
+ if pre:
+ rv.append(SQL(pre))
+
+ if name is None:
+ continue
+
+ if name.isdigit():
+ if autonum:
+ raise ValueError(
+ "cannot switch from automatic field numbering to manual"
+ )
+ rv.append(args[int(name)])
+ autonum = None
+
+ elif not name:
+ if autonum is None:
+ raise ValueError(
+ "cannot switch from manual field numbering to automatic"
+ )
+ rv.append(args[autonum])
+ autonum += 1
+
+ else:
+ rv.append(kwargs[name])
+
+ return Composed(rv)
+
+ def join(self, seq: Iterable[Composable]) -> Composed:
+ """
+ Join a sequence of `Composable`.
+
+ :param seq: the elements to join.
+ :type seq: iterable of `!Composable`
+
+ Use the `!SQL` object's string to separate the elements in `!seq`.
+ Note that `Composed` objects are iterable too, so they can be used as
+ argument for this method.
+
+ Example::
+
+ >>> snip = sql.SQL(', ').join(
+ ... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
+ >>> print(snip.as_string(conn))
+ "foo", "bar", "baz"
+ """
+ rv = []
+ it = iter(seq)
+ try:
+ rv.append(next(it))
+ except StopIteration:
+ pass
+ else:
+ for i in it:
+ rv.append(self)
+ rv.append(i)
+
+ return Composed(rv)
+
+
+class Identifier(Composable):
+ """
+ A `Composable` representing an SQL identifier or a dot-separated sequence.
+
+ Identifiers usually represent names of database objects, such as tables or
+ fields. PostgreSQL identifiers follow `different rules`__ than SQL string
+ literals for escaping (e.g. they use double quotes instead of single).
+
+ .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
+ SQL-SYNTAX-IDENTIFIERS
+
+ Example::
+
+ >>> t1 = sql.Identifier("foo")
+ >>> t2 = sql.Identifier("ba'r")
+ >>> t3 = sql.Identifier('ba"z')
+ >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
+ "foo", "ba'r", "ba""z"
+
+ Multiple strings can be passed to the object to represent a qualified name,
+ i.e. a dot-separated sequence of identifiers.
+
+ Example::
+
+ >>> query = sql.SQL("SELECT {} FROM {}").format(
+ ... sql.Identifier("table", "field"),
+ ... sql.Identifier("schema", "table"))
+ >>> print(query.as_string(conn))
+ SELECT "table"."field" FROM "schema"."table"
+
+ """
+
+ _obj: Sequence[str]
+
+ def __init__(self, *strings: str):
+ # init super() now to make the __repr__ not explode in case of error
+ super().__init__(strings)
+
+ if not strings:
+ raise TypeError("Identifier cannot be empty")
+
+ for s in strings:
+ if not isinstance(s, str):
+ raise TypeError(
+ f"SQL identifier parts must be strings, got {s!r} instead"
+ )
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ if not conn:
+ raise ValueError("a connection is necessary for Identifier")
+ esc = Escaping(conn.pgconn)
+ enc = conn_encoding(conn)
+ escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ return b".".join(escs)
+
+
+class Literal(Composable):
+ """
+ A `Composable` representing an SQL value to include in a query.
+
+ Usually you will want to include placeholders in the query and pass values
+ as `~cursor.execute()` arguments. If however you really really need to
+ include a literal value in the query you can use this object.
+
+ The string returned by `!as_string()` follows the normal :ref:`adaptation
+ rules <types-adaptation>` for Python objects.
+
+ Example::
+
+ >>> s1 = sql.Literal("fo'o")
+ >>> s2 = sql.Literal(42)
+ >>> s3 = sql.Literal(date(2000, 1, 1))
+ >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
+ 'fo''o', 42, '2000-01-01'::date
+
+ """
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ tx = Transformer.from_context(context)
+ return tx.as_literal(self._obj)
+
+
+class Placeholder(Composable):
+ """A `Composable` representing a placeholder for query parameters.
+
+ If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
+ ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
+ ``%b``).
+
+ The object is useful to generate SQL queries with a variable number of
+ arguments.
+
+ Examples::
+
+ >>> names = ['foo', 'bar', 'baz']
+
+ >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(sql.Placeholder() * len(names)))
+ >>> print(q1.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)
+
+ >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(map(sql.Placeholder, names)))
+ >>> print(q2.as_string(conn))
+ INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
+
+ """
+
+ def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
+ super().__init__(name)
+ if not isinstance(name, str):
+ raise TypeError(f"expected string as name, got {name!r}")
+
+ if ")" in name:
+ raise ValueError(f"invalid name: {name!r}")
+
+ if type(format) is str:
+ format = PyFormat(format)
+ if not isinstance(format, PyFormat):
+ raise TypeError(
+ f"expected PyFormat as format, got {type(format).__name__!r}"
+ )
+
+ self._format: PyFormat = format
+
+ def __repr__(self) -> str:
+ parts = []
+ if self._obj:
+ parts.append(repr(self._obj))
+ if self._format is not PyFormat.AUTO:
+ parts.append(f"format={self._format.name}")
+
+ return f"{self.__class__.__name__}({', '.join(parts)})"
+
+ def as_string(self, context: Optional[AdaptContext]) -> str:
+ code = self._format.value
+ return f"%({self._obj}){code}" if self._obj else f"%{code}"
+
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ enc = conn_encoding(conn)
+ return self.as_string(context).encode(enc)
+
+
+# Literals
+NULL = SQL("NULL")
+DEFAULT = SQL("DEFAULT")
diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py
new file mode 100644
index 0000000..e13486e
--- /dev/null
+++ b/psycopg/psycopg/transaction.py
@@ -0,0 +1,290 @@
+"""
+Transaction context managers returned by Connection.transaction()
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from types import TracebackType
+from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
+
+from . import pq
+from . import sql
+from . import errors as e
+from .abc import ConnectionType, PQGen
+
+if TYPE_CHECKING:
+ from typing import Any
+ from .connection import Connection
+ from .connection_async import AsyncConnection
+
+IDLE = pq.TransactionStatus.IDLE
+
+OK = pq.ConnStatus.OK
+
+logger = logging.getLogger(__name__)
+
+
+class Rollback(Exception):
+ """
+    Exit the current `Transaction` context immediately and roll back any
+    changes made within this context.
+
+    If a transaction context is specified in the constructor, roll back the
+    enclosing transaction contexts up to and including the one specified.
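+
+    Example (an illustrative sketch, assuming an open connection `!conn`)::
+
+        with conn.transaction() as outer:
+            with conn.transaction():
+                ...
+                raise Rollback(outer)  # roll back both blocks
+        # execution continues here; nothing above was committed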
+ """
+
+ __module__ = "psycopg"
+
+ def __init__(
+ self,
+ transaction: Union["Transaction", "AsyncTransaction", None] = None,
+ ):
+ self.transaction = transaction
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__qualname__}({self.transaction!r})"
+
+
+class OutOfOrderTransactionNesting(e.ProgrammingError):
+ """Out-of-order transaction nesting detected"""
+
+
+class BaseTransaction(Generic[ConnectionType]):
+ def __init__(
+ self,
+ connection: ConnectionType,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ):
+ self._conn = connection
+ self.pgconn = self._conn.pgconn
+ self._savepoint_name = savepoint_name or ""
+ self.force_rollback = force_rollback
+ self._entered = self._exited = False
+ self._outer_transaction = False
+ self._stack_index = -1
+
+ @property
+ def savepoint_name(self) -> Optional[str]:
+ """
+ The name of the savepoint; `!None` if handling the main transaction.
+ """
+ # Yes, it may change on __enter__. No, I don't care, because the
+ # un-entered state is outside the public interface.
+ return self._savepoint_name
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = pq.misc.connection_summary(self.pgconn)
+ if not self._entered:
+ status = "inactive"
+ elif not self._exited:
+ status = "active"
+ else:
+ status = "terminated"
+
+ sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
+ return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
+
+ def _enter_gen(self) -> PQGen[None]:
+ if self._entered:
+ raise TypeError("transaction blocks can be used only once")
+ self._entered = True
+
+ self._push_savepoint()
+ for command in self._get_enter_commands():
+ yield from self._conn._exec_command(command)
+
+ def _exit_gen(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> PQGen[bool]:
+ if not exc_val and not self.force_rollback:
+ yield from self._commit_gen()
+ return False
+ else:
+ # try to rollback, but if there are problems (connection in a bad
+ # state) just warn without clobbering the exception bubbling up.
+ try:
+ return (yield from self._rollback_gen(exc_val))
+ except OutOfOrderTransactionNesting:
+                # Clobber any exception raised in the block with the exception
+                # caused by the out-of-order transaction nesting detected: this
+                # makes the behaviour consistent with _commit_gen and makes
+                # sure the user fixes this condition, which is unrelated to any
+                # operational error that might arise in the block.
+ raise
+ except Exception as exc2:
+ logger.warning("error ignored in rollback of %s: %s", self, exc2)
+ return False
+
+ def _commit_gen(self) -> PQGen[None]:
+ ex = self._pop_savepoint("commit")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_commit_commands():
+ yield from self._conn._exec_command(command)
+
+ def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
+ if isinstance(exc_val, Rollback):
+ logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
+
+ ex = self._pop_savepoint("rollback")
+ self._exited = True
+ if ex:
+ raise ex
+
+ for command in self._get_rollback_commands():
+ yield from self._conn._exec_command(command)
+
+ if isinstance(exc_val, Rollback):
+ if not exc_val.transaction or exc_val.transaction is self:
+ return True # Swallow the exception
+
+ return False
+
+ def _get_enter_commands(self) -> Iterator[bytes]:
+ if self._outer_transaction:
+ yield self._conn._get_tx_start_command()
+
+ if self._savepoint_name:
+ yield (
+ sql.SQL("SAVEPOINT {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ def _get_commit_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("RELEASE {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"COMMIT"
+
+ def _get_rollback_commands(self) -> Iterator[bytes]:
+ if self._savepoint_name and not self._outer_transaction:
+ yield (
+ sql.SQL("ROLLBACK TO {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+ yield (
+ sql.SQL("RELEASE {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._num_transactions
+ yield b"ROLLBACK"
+
+ # Also clear the prepared statements cache.
+ if self._conn._prepared.clear():
+ yield from self._conn._prepared.get_maintenance_commands()
+
+ def _push_savepoint(self) -> None:
+ """
+ Push the transaction on the connection transactions stack.
+
+ Also set the internal state of the object and verify consistency.
+ """
+ self._outer_transaction = self.pgconn.transaction_status == IDLE
+ if self._outer_transaction:
+ # outer transaction: if no name it's only a begin, else
+ # there will be an additional savepoint
+ assert not self._conn._num_transactions
+ else:
+ # inner transaction: it always has a name
+ if not self._savepoint_name:
+ self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
+
+ self._stack_index = self._conn._num_transactions
+ self._conn._num_transactions += 1
+
+ def _pop_savepoint(self, action: str) -> Optional[Exception]:
+ """
+ Pop the transaction from the connection transactions stack.
+
+ Also verify the state consistency.
+ """
+ self._conn._num_transactions -= 1
+ if self._conn._num_transactions == self._stack_index:
+ return None
+
+ return OutOfOrderTransactionNesting(
+ f"transaction {action} at the wrong nesting level: {self}"
+ )
+
+
+class Transaction(BaseTransaction["Connection[Any]"]):
+ """
+ Returned by `Connection.transaction()` to handle a transaction block.
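+
+    Example (an illustrative sketch, assuming an open connection `!conn` and a
+    cursor `!cur`)::
+
+        with conn.transaction():
+            cur.execute("INSERT INTO data (num) VALUES (%s)", (42,))
+        # committed on exit; rolled back if an exception escaped the block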
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="Transaction")
+
+ @property
+ def connection(self) -> "Connection[Any]":
+ """The connection the object is managing."""
+ return self._conn
+
+ def __enter__(self: _Self) -> _Self:
+ with self._conn.lock:
+ self._conn.wait(self._enter_gen())
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ with self._conn.lock:
+ return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
+
+
+class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
+ """
+ Returned by `AsyncConnection.transaction()` to handle a transaction block.
+ """
+
+ __module__ = "psycopg"
+
+ _Self = TypeVar("_Self", bound="AsyncTransaction")
+
+ @property
+ def connection(self) -> "AsyncConnection[Any]":
+ return self._conn
+
+ async def __aenter__(self: _Self) -> _Self:
+ async with self._conn.lock:
+ await self._conn.wait(self._enter_gen())
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self.pgconn.status == OK:
+ async with self._conn.lock:
+ return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
+ else:
+ return False
diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py
new file mode 100644
index 0000000..bdddf05
--- /dev/null
+++ b/psycopg/psycopg/types/__init__.py
@@ -0,0 +1,11 @@
+"""
+psycopg types package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import _typeinfo
+
+# Exposed here
+TypeInfo = _typeinfo.TypeInfo
+TypesRegistry = _typeinfo.TypesRegistry
diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py
new file mode 100644
index 0000000..e35c5e7
--- /dev/null
+++ b/psycopg/psycopg/types/array.py
@@ -0,0 +1,464 @@
+"""
+Adapters for arrays
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
+
+from .. import pq
+from .. import errors as e
+from .. import postgres
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._compat import cache, prod
+from .._struct import pack_len, unpack_len
+from .._cmodule import _psycopg
+from ..postgres import TEXT_OID, INVALID_OID
+from .._typeinfo import TypeInfo
+
+_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
+_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
+_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
+_struct_dim = struct.Struct("!II") # dim, lower bound
+_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
+_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
+
+TEXT_ARRAY_OID = postgres.types["text"].array_oid
+
+PY_TEXT = PyFormat.TEXT
+PQ_BINARY = pq.Format.BINARY
+
+
+class BaseListDumper(RecursiveDumper):
+ element_oid = 0
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ if cls is NoneType:
+ cls = list
+
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ if self.element_oid and context:
+ sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
+ self.sub_dumper = sdclass(NoneType, context)
+
+ def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
+ """
+        Find the first non-null element of a possibly nested list.
+ """
+ items = list(self._flatiter(L, set()))
+ types = {type(item): item for item in items}
+ if not types:
+ return None
+
+ if len(types) == 1:
+ t, v = types.popitem()
+ else:
+            # More than one type in the list. It might still be fine, as long
+ # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
+ dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
+ oids = set(d.oid for d in dumpers)
+ if len(oids) == 1:
+ t, v = types.popitem()
+ else:
+ raise e.DataError(
+ "cannot dump lists of mixed types;"
+ f" got: {', '.join(sorted(t.__name__ for t in types))}"
+ )
+
+ # Checking for precise type. If the type is a subclass (e.g. Int4)
+ # we assume the user knows what type they are passing.
+ if t is not int:
+ return v
+
+ # If we got an int, let's see what is the biggest one in order to
+ # choose the smallest OID and allow Postgres to do the right cast.
+ imax: int = max(items)
+ imin: int = min(items)
+ if imin >= 0:
+ return imax
+ else:
+ return max(imax, -imin - 1)
+
+ def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
+ if id(L) in seen:
+ raise e.DataError("cannot dump a recursive list")
+
+ seen.add(id(L))
+
+ for item in L:
+ if type(item) is list:
+ yield from self._flatiter(item, seen)
+ elif item is not None:
+ yield item
+
+ return None
+
+ def _get_base_type_info(self, base_oid: int) -> TypeInfo:
+ """
+ Return info about the base type.
+
+ Return text info as fallback.
+ """
+ if base_oid:
+ info = self._tx.adapters.types.get(base_oid)
+ if info:
+ return info
+
+ return self._tx.adapters.types["text"]
+
+
+class ListDumper(BaseListDumper):
+
+ delimiter = b","
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return self.cls
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ # Empty lists can only be dumped as text if the type is unknown.
+ return self
+
+ sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+
+ # We consider an array of unknowns as unknown, so we can dump empty
+ # lists or lists containing only None elements.
+ if sd.oid != INVALID_OID:
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+ dumper.delimiter = info.delimiter.encode()
+ else:
+ dumper.oid = INVALID_OID
+
+ return dumper
+
+ # Double quotes and backslashes embedded in element values will be
+ # backslash-escaped.
+ _re_esc = re.compile(rb'(["\\])')
+
+ def dump(self, obj: List[Any]) -> bytes:
+ tokens: List[Buffer] = []
+ needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
+
+ def dump_list(obj: List[Any]) -> None:
+ if not obj:
+ tokens.append(b"{}")
+ return
+
+ tokens.append(b"{")
+ for item in obj:
+ if isinstance(item, list):
+ dump_list(item)
+ elif item is not None:
+ ad = self._dump_item(item)
+ if needs_quotes(ad):
+ if not isinstance(ad, bytes):
+ ad = bytes(ad)
+ ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
+ tokens.append(ad)
+ else:
+ tokens.append(b"NULL")
+
+ tokens.append(self.delimiter)
+
+ tokens[-1] = b"}"
+
+ dump_list(obj)
+
+ return b"".join(tokens)
+
+ def _dump_item(self, item: Any) -> Buffer:
+ if self.sub_dumper:
+ return self.sub_dumper.dump(item)
+ else:
+ return self._tx.get_dumper(item, PY_TEXT).dump(item)
+
+
+@cache
+def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """Return a regexp to recognise when a value needs quotes
+
+ from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
+
+ The array output routine will put double quotes around element values if
+ they are empty strings, contain curly braces, delimiter characters,
+ double quotes, backslashes, or white space, or match the word NULL.
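+
+    For instance (illustrative), dumping the list ``["ab", "a b", None]``
+    produces ``{ab,"a b",NULL}``: only the value containing white space is
+    quoted.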
+ """
+ return re.compile(
+ rb"""(?xi)
+ ^$ # the empty string
+ | ["{}%s\\\s] # or a char to escape
+ | ^null$ # or the word NULL
+ """
+ % delimiter
+ )
+
+
+class ListBinaryDumper(BaseListDumper):
+
+ format = pq.Format.BINARY
+
+ def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
+ if self.oid:
+ return self.cls
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return (self.cls,)
+
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.get_key(item, format))
+
+ def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
+ # If we have an oid we don't need to upgrade
+ if self.oid:
+ return self
+
+ item = self._find_list_element(obj, format)
+ if item is None:
+ return ListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format.from_pq(self.format))
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ info = self._get_base_type_info(sd.oid)
+ dumper.oid = info.array_oid or TEXT_ARRAY_OID
+
+ return dumper
+
+ def dump(self, obj: List[Any]) -> bytes:
+ # Postgres won't take unknown for element oid: fall back on text
+ sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
+
+ if not obj:
+ return _pack_head(0, 0, sub_oid)
+
+ data: List[Buffer] = [b"", b""] # placeholders to avoid a resize
+ dims: List[int] = []
+ hasnull = 0
+
+ def calc_dims(L: List[Any]) -> None:
+ if isinstance(L, self.cls):
+ if not L:
+ raise e.DataError("lists cannot contain empty lists")
+ dims.append(len(L))
+ calc_dims(L[0])
+
+ calc_dims(obj)
+
+ def dump_list(L: List[Any], dim: int) -> None:
+ nonlocal hasnull
+ if len(L) != dims[dim]:
+ raise e.DataError("nested lists have inconsistent lengths")
+
+ if dim == len(dims) - 1:
+ for item in L:
+ if item is not None:
+ # If we get here, the sub_dumper must have been set
+ ad = self.sub_dumper.dump(item) # type: ignore[union-attr]
+ data.append(pack_len(len(ad)))
+ data.append(ad)
+ else:
+ hasnull = 1
+ data.append(b"\xff\xff\xff\xff")
+ else:
+ for item in L:
+ if not isinstance(item, self.cls):
+ raise e.DataError("nested lists have inconsistent depths")
+ dump_list(item, dim + 1) # type: ignore
+
+ dump_list(obj, 0)
+
+ data[0] = _pack_head(len(dims), hasnull, sub_oid)
+ data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
+ return b"".join(data)
+
+
+class ArrayLoader(RecursiveLoader):
+
+ delimiter = b","
+ base_oid: int
+
+ def load(self, data: Buffer) -> List[Any]:
+ loader = self._tx.get_loader(self.base_oid, self.format)
+ return _load_text(data, loader, self.delimiter)
+
+
+class ArrayBinaryLoader(RecursiveLoader):
+
+ format = pq.Format.BINARY
+
+ def load(self, data: Buffer) -> List[Any]:
+ return _load_binary(data, self._tx)
+
+
+def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ if not info.array_oid:
+ raise ValueError(f"the type info {info} doesn't describe an array")
+
+ base: Type[Any]
+ adapters = context.adapters if context else postgres.adapters
+
+ base = getattr(_psycopg, "ArrayLoader", ArrayLoader)
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "base_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ loader = type(name, (base,), attribs)
+ adapters.register_loader(info.array_oid, loader)
+
+ loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
+ adapters.register_loader(info.array_oid, loader)
+
+ base = ListDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ "delimiter": info.delimiter.encode(),
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+ base = ListBinaryDumper
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {
+ "oid": info.array_oid,
+ "element_oid": info.oid,
+ }
+ dumper = type(name, (base,), attribs)
+ adapters.register_dumper(None, dumper)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ # The text dumper is more flexible as it can handle lists of mixed type,
+ # so register it later.
+ context.adapters.register_dumper(list, ListBinaryDumper)
+ context.adapters.register_dumper(list, ListDumper)
+
+
+def register_all_arrays(context: AdaptContext) -> None:
+ """
+    Register the array adapters for all the known types having an array oid.
+
+ This function is designed to be called once at import time, after having
+ registered all the base loaders.
+ """
+ for t in context.adapters.types:
+ if t.array_oid:
+ t.register(context)
+
+
+def _load_text(
+ data: Buffer,
+ loader: Loader,
+ delimiter: bytes = b",",
+ __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
+) -> List[Any]:
+ rv = None
+ stack: List[Any] = []
+ a: List[Any] = []
+ rv = a
+ load = loader.load
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if data and data[0] == b"["[0]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ idx = data.find(b"=")
+ if idx == -1:
+ raise e.DataError("malformed array: no '=' after dimension information")
+ data = data[idx + 1 :]
+
+ re_parse = _get_array_parse_regexp(delimiter)
+ for m in re_parse.finditer(data):
+ t = m.group(1)
+ if t == b"{":
+ if stack:
+ stack[-1].append(a)
+ stack.append(a)
+ a = []
+
+ elif t == b"}":
+ if not stack:
+ raise e.DataError("malformed array: unexpected '}'")
+ rv = stack.pop()
+
+ else:
+ if not stack:
+                wat = t[:10].decode("utf8", "replace") + ("..." if len(t) > 10 else "")
+ raise e.DataError(f"malformed array: unexpected '{wat}'")
+ if t == b"NULL":
+ v = None
+ else:
+ if t.startswith(b'"'):
+ t = __re_unescape.sub(rb"\1", t[1:-1])
+ v = load(t)
+
+ stack[-1].append(v)
+
+ assert rv is not None
+ return rv
+
+
+@cache
+def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
+ """
+ Return a regexp to tokenize an array representation into item and brackets
+ """
+ return re.compile(
+ rb"""(?xi)
+ ( [{}] # open or closed bracket
+ | " (?: [^"\\] | \\. )* " # or a quoted string
+ | [^"{}%s\\]+ # or an unquoted non-empty string
+ ) ,?
+ """
+ % delimiter
+ )
+
+
+def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
+ ndims, hasnull, oid = _unpack_head(data)
+ load = tx.get_loader(oid, PQ_BINARY).load
+
+ if not ndims:
+ return []
+
+ p = 12 + 8 * ndims
+ dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
+ nelems = prod(dims)
+
+ out: List[Any] = [None] * nelems
+ for i in range(nelems):
+ size = unpack_len(data, p)[0]
+ p += 4
+ if size == -1:
+ continue
+ out[i] = load(data[p : p + size])
+ p += size
+
+    # for ndims > 1 we have to aggregate the array into sub-arrays
+ for dim in dims[-1:0:-1]:
+ out = [out[i : i + dim] for i in range(0, len(out), dim)]
+
+ return out
diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py
new file mode 100644
index 0000000..db7e181
--- /dev/null
+++ b/psycopg/psycopg/types/bool.py
@@ -0,0 +1,51 @@
+"""
+Adapters for booleans.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+
+class BoolDumper(Dumper):
+
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"t" if obj else b"f"
+
+ def quote(self, obj: bool) -> bytes:
+ return b"true" if obj else b"false"
+
+
+class BoolBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bool"].oid
+
+ def dump(self, obj: bool) -> bytes:
+ return b"\x01" if obj else b"\x00"
+
+
+class BoolLoader(Loader):
+ def load(self, data: Buffer) -> bool:
+ return data == b"t"
+
+
+class BoolBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> bool:
+ return data != b"\x00"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(bool, BoolDumper)
+ adapters.register_dumper(bool, BoolBinaryDumper)
+ adapters.register_loader("bool", BoolLoader)
+ adapters.register_loader("bool", BoolBinaryLoader)
diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py
new file mode 100644
index 0000000..1c609c3
--- /dev/null
+++ b/psycopg/psycopg/types/composite.py
@@ -0,0 +1,290 @@
+"""
+Support for composite types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from collections import namedtuple
+from typing import Any, Callable, cast, Iterator, List, Optional
+from typing import Sequence, Tuple, Type
+
+from .. import pq
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader
+from .._struct import pack_len, unpack_len
+from ..postgres import TEXT_OID
+from .._typeinfo import CompositeInfo as CompositeInfo # exported here
+from .._encodings import _as_python_identifier
+
+_struct_oidlen = struct.Struct("!Ii")
+_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
+_unpack_oidlen = cast(
+ Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
+)
+
+
+class SequenceDumper(RecursiveDumper):
+ def _dump_sequence(
+ self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
+ ) -> bytes:
+ if not obj:
+ return start + end
+
+ parts: List[Buffer] = [start]
+
+ for item in obj:
+ if item is None:
+ parts.append(sep)
+ continue
+
+ dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
+ ad = dumper.dump(item)
+ if not ad:
+ ad = b'""'
+ elif self._re_needs_quotes.search(ad):
+ ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"'
+
+ parts.append(ad)
+ parts.append(sep)
+
+ parts[-1] = end
+
+ return b"".join(parts)
+
+ _re_needs_quotes = re.compile(rb'[",\\\s()]')
+ _re_esc = re.compile(rb"([\\\"])")
+
+
+class TupleDumper(SequenceDumper):
+
+ # Should be this, but it doesn't work
+ # oid = postgres_types["record"].oid
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytes:
+ return self._dump_sequence(obj, b"(", b")", b",")
+
+
+class TupleBinaryDumper(RecursiveDumper):
+
+ format = pq.Format.BINARY
+
+ # Subclasses must set an info
+ info: CompositeInfo
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ nfields = len(self.info.field_types)
+ self._tx.set_dumper_types(self.info.field_types, self.format)
+ self._formats = (PyFormat.from_pq(self.format),) * nfields
+
+ def dump(self, obj: Tuple[Any, ...]) -> bytearray:
+ out = bytearray(pack_len(len(obj)))
+ adapted = self._tx.dump_sequence(obj, self._formats)
+ for i in range(len(obj)):
+ b = adapted[i]
+ oid = self.info.field_types[i]
+ if b is not None:
+ out += _pack_oidlen(oid, len(b))
+ out += b
+ else:
+ out += _pack_oidlen(oid, -1)
+
+ return out
+
+
+class BaseCompositeLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
+ """
+ Split a non-empty representation of a composite type into components.
+
+ Terminators shouldn't be used in `!data` (so that both record and range
+ representations can be parsed).
+ """
+ for m in self._re_tokenize.finditer(data):
+ if m.group(1):
+ yield None
+ elif m.group(2) is not None:
+ yield self._re_undouble.sub(rb"\1", m.group(2))
+ else:
+ yield m.group(3)
+
+ # If the final group ended in `,` there is a final NULL in the record
+ # that the regexp couldn't parse.
+ if m and m.group().endswith(b","):
+ yield None
+
+ _re_tokenize = re.compile(
+ rb"""(?x)
+ (,) # an empty token, representing NULL
+ | " ((?: [^"] | "")*) " ,? # or a quoted string
+ | ([^",)]+) ,? # or an unquoted string
+ """
+ )
+
+ _re_undouble = re.compile(rb'(["\\])\1')
+
+
+class RecordLoader(BaseCompositeLoader):
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if data == b"()":
+ return ()
+
+ cast = self._tx.get_loader(TEXT_OID, self.format).load
+ return tuple(
+ cast(token) if token is not None else None
+ for token in self._parse_record(data[1:-1])
+ )
+
+
+class RecordBinaryLoader(Loader):
+ format = pq.Format.BINARY
+ _types_set = False
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._tx = Transformer(context)
+
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ return self._tx.load_sequence(
+ tuple(
+ data[offset : offset + length] if length != -1 else None
+ for _, offset, length in self._walk_record(data)
+ )
+ )
+
+ def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
+ """
+ Yield a sequence of (oid, offset, length) for the content of the record
+ """
+ nfields = unpack_len(data, 0)[0]
+ i = 4
+ for _ in range(nfields):
+ oid, length = _unpack_oidlen(data, i)
+ yield oid, i + 8, length
+ i += (8 + length) if length > 0 else 8
+
+ def _config_types(self, data: Buffer) -> None:
+ oids = [r[0] for r in self._walk_record(data)]
+ self._tx.set_loader_types(oids, self.format)
+
+
+class CompositeLoader(RecordLoader):
+
+ factory: Callable[..., Any]
+ fields_types: List[int]
+ _types_set = False
+
+ def load(self, data: Buffer) -> Any:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ if data == b"()":
+ return type(self).factory()
+
+ return type(self).factory(
+ *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
+ )
+
+ def _config_types(self, data: Buffer) -> None:
+ self._tx.set_loader_types(self.fields_types, self.format)
+
+
+class CompositeBinaryLoader(RecordBinaryLoader):
+
+ format = pq.Format.BINARY
+ factory: Callable[..., Any]
+
+ def load(self, data: Buffer) -> Any:
+ r = super().load(data)
+ return type(self).factory(*r)
+
+
+def register_composite(
+ info: CompositeInfo,
+ context: Optional[AdaptContext] = None,
+ factory: Optional[Callable[..., Any]] = None,
+) -> None:
+ """Register the adapters to load and dump a composite type.
+
+ :param info: The object with the information about the composite to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param factory: Callable to convert the sequence of attributes read from
+ the composite into a Python object.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested composite available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ if not factory:
+ factory = namedtuple( # type: ignore
+ _as_python_identifier(info.name),
+ [_as_python_identifier(n) for n in info.field_names],
+ )
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[BaseCompositeLoader] = type(
+ f"{info.name.title()}Loader",
+ (CompositeLoader,),
+ {
+ "factory": factory,
+ "fields_types": info.field_types,
+ },
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ loader = type(
+ f"{info.name.title()}BinaryLoader",
+ (CompositeBinaryLoader,),
+ {"factory": factory},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # If the factory is a type, create and register dumpers for it
+ if isinstance(factory, type):
+ dumper = type(
+ f"{info.name.title()}BinaryDumper",
+ (TupleBinaryDumper,),
+ {"oid": info.oid, "info": info},
+ )
+ adapters.register_dumper(factory, dumper)
+
+ # Default to the text dumper because it is more flexible
+ dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid})
+ adapters.register_dumper(factory, dumper)
+
+ info.python_type = factory
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(tuple, TupleDumper)
+ adapters.register_loader("record", RecordLoader)
+ adapters.register_loader("record", RecordBinaryLoader)
diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py
new file mode 100644
index 0000000..f0dfe83
--- /dev/null
+++ b/psycopg/psycopg/types/datetime.py
@@ -0,0 +1,754 @@
+"""
+Adapters for date/time types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import struct
+from datetime import date, datetime, time, timedelta, timezone
+from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from .._tz import get_tzinfo
+from ..abc import AdaptContext, DumperKey
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from ..errors import InterfaceError, DataError
+from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
+
+if TYPE_CHECKING:
+ from ..connection import BaseConnection
+
+_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
+_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
+_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
+
+_struct_interval = struct.Struct("!qii") # microseconds, days, months
+_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
+_unpack_interval = cast(
+ Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
+)
+
+utc = timezone.utc
+_pg_date_epoch_days = date(2000, 1, 1).toordinal()
+_pg_datetime_epoch = datetime(2000, 1, 1)
+_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
+_py_date_min_days = date.min.toordinal()
+
+
+class DateDumper(Dumper):
+
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DateBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["date"].oid
+
+ def dump(self, obj: date) -> bytes:
+ days = obj.toordinal() - _pg_date_epoch_days
+ return pack_int4(days)
+
+
+class _BaseTimeDumper(Dumper):
+ def get_key(self, obj: time, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseTimeTextDumper(_BaseTimeDumper):
+ def dump(self, obj: time) -> bytes:
+ return str(obj).encode()
+
+
+class TimeDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["time"].oid
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+class TimeTzDumper(_BaseTimeTextDumper):
+
+ oid = postgres.types["timetz"].oid
+
+
+class TimeBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["time"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ return pack_int8(us)
+
+ def upgrade(self, obj: time, format: PyFormat) -> Dumper:
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzBinaryDumper(self.cls)
+
+
+class TimeTzBinaryDumper(_BaseTimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timetz"].oid
+
+ def dump(self, obj: time) -> bytes:
+ us = obj.microsecond + 1_000_000 * (
+ obj.second + 60 * (obj.minute + 60 * obj.hour)
+ )
+ off = obj.utcoffset()
+ assert off is not None
+ return _pack_timetz(us, -int(off.total_seconds()))
+
+
+class _BaseDatetimeDumper(Dumper):
+ def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ raise NotImplementedError
+
+
+class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
+ def dump(self, obj: datetime) -> bytes:
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ return str(obj).encode()
+
+
+class DatetimeDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamptz"].oid
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzDumper(self.cls)
+
+
+class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
+
+ oid = postgres.types["timestamp"].oid
+
+
+class DatetimeBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamptz"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetimetz_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+ def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzBinaryDumper(self.cls)
+
+
+class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["timestamp"].oid
+
+ def dump(self, obj: datetime) -> bytes:
+ delta = obj - _pg_datetime_epoch
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
+ return pack_int8(micros)
+
+
+class TimedeltaDumper(Dumper):
+
+ oid = postgres.types["interval"].oid
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ if self.connection:
+ if (
+ self.connection.pgconn.parameter_status(b"IntervalStyle")
+ == b"sql_standard"
+ ):
+ setattr(self, "dump", self._dump_sql)
+
+ def dump(self, obj: timedelta) -> bytes:
+ # The comma is parsed ok by PostgreSQL but it's not documented
+ # and it seems brittle to rely on it. CRDB doesn't consume it well.
+ return str(obj).encode().replace(b",", b"")
+
+ def _dump_sql(self, obj: timedelta) -> bytes:
+ # sql_standard format needs explicit signs
+ # otherwise -1 day 1 sec will mean -1 sec
+ return b"%+d day %+d second %+d microsecond" % (
+ obj.days,
+ obj.seconds,
+ obj.microseconds,
+ )
+
+
+class TimedeltaBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["interval"].oid
+
+ def dump(self, obj: timedelta) -> bytes:
+ micros = 1_000_000 * obj.seconds + obj.microseconds
+ return _pack_interval(micros, obj.days, 0)
+
+
+class DateLoader(Loader):
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> date:
+ if self._order == self._ORDER_YMD:
+ ye = data[:4]
+ mo = data[5:7]
+ da = data[8:]
+ elif self._order == self._ORDER_DMY:
+ da = data[:2]
+ mo = data[3:5]
+ ye = data[6:]
+ else:
+ mo = data[:2]
+ da = data[3:5]
+ ye = data[6:]
+
+ try:
+ return date(int(ye), int(mo), int(da))
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ if s == "infinity" or (s and len(s.split()[0]) > 10):
+ raise DataError(f"date too large (after year 10K): {s!r}") from None
+ elif s == "-infinity" or "BC" in s:
+ raise DataError(f"date too small (before year 1): {s!r}") from None
+ else:
+ raise DataError(f"can't parse date {s!r}: {ex}") from None
+
+
+class DateBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> date:
+ days = unpack_int4(data)[0] + _pg_date_epoch_days
+ try:
+ return date.fromordinal(days)
+ except (ValueError, OverflowError):
+ if days < _py_date_min_days:
+ raise DataError("date too small (before year 1)") from None
+ else:
+ raise DataError("date too large (after year 10K)") from None
+
+
+class TimeLoader(Loader):
+
+ _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}")
+
+ ho, mi, se, fr = m.groups()
+
+ # Pad the fraction of second to get micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return time(int(ho), int(mi), int(se), us)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse time {s!r}: {e}") from None
+
+
+class TimeBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val = unpack_int8(data)[0]
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ try:
+ return time(h, m, s, us)
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimetzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def load(self, data: Buffer) -> time:
+ m = self._re_format.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}")
+
+ ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone
+ off = 60 * 60 * int(oh)
+ if om:
+ off += 60 * int(om)
+ if os:
+ off += int(os)
+ tz = timezone(timedelta(0, off if sgn == b"+" else -off))
+
+ try:
+ return time(int(ho), int(mi), int(se), us, tz)
+ except ValueError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse timetz {s!r}: {e}") from None
+
+
+class TimetzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> time:
+ val, off = _unpack_timetz(data)
+
+ val, us = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+
+ try:
+ return time(h, m, s, us, timezone(timedelta(seconds=-off)))
+ except ValueError:
+ raise DataError(f"time not supported by Python: hour={h}") from None
+
+
+class TimestampLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ $
+ """
+ )
+ _re_format_pg = re.compile(
+ rb"""(?ix)
+ ^
+ [a-z]+ [^a-z0-9] # DoW, separator
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+|[a-z]+) [^a-z0-9] # Month or day
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ [^a-z0-9] (\d+) # Year
+ $
+ """
+ )
+
+ _ORDER_YMD = 0
+ _ORDER_DMY = 1
+ _ORDER_MDY = 2
+ _ORDER_PGDM = 3
+ _ORDER_PGMD = 4
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+
+ ds = _get_datestyle(self.connection)
+ if ds.startswith(b"I"): # ISO
+ self._order = self._ORDER_YMD
+ elif ds.startswith(b"G"): # German
+ self._order = self._ORDER_DMY
+ elif ds.startswith(b"S"): # SQL
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
+ elif ds.startswith(b"P"): # Postgres
+ self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
+ self._re_format = self._re_format_pg
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ if self._order == self._ORDER_YMD:
+ ye, mo, da, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_DMY:
+ da, mo, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ elif self._order == self._ORDER_MDY:
+ mo, da, ye, ho, mi, se, fr = m.groups()
+ imo = int(mo)
+ else:
+ if self._order == self._ORDER_PGDM:
+ da, mo, ho, mi, se, fr, ye = m.groups()
+ else:
+ mo, da, ho, mi, se, fr, ye = m.groups()
+ try:
+ imo = _month_abbr[mo]
+ except KeyError:
+ s = mo.decode("utf8", "replace")
+ raise DataError(f"can't parse month: {s!r}") from None
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ try:
+ return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+
+class TimestampBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ return _pg_datetime_epoch + timedelta(microseconds=micros)
+ except OverflowError:
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class TimestamptzLoader(Loader):
+
+ _re_format = re.compile(
+ rb"""(?ix)
+ ^
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date
+ (?: T | [^a-z0-9] ) # Separator, including T
+ (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time
+ (?: \.(\d+) )? # Micros
+ ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone
+ $
+ """
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ ds = _get_datestyle(self.connection)
+ if not ds.startswith(b"I"): # not ISO
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> datetime:
+ m = self._re_format.match(data)
+ if not m:
+ raise _get_timestamp_load_error(self.connection, data) from None
+
+ ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
+
+ # Pad the fraction of second to get the micros
+ if fr:
+ us = int(fr)
+ if len(fr) < 6:
+ us *= _uspad[len(fr)]
+ else:
+ us = 0
+
+ # Calculate timezone offset
+ soff = 60 * 60 * int(oh)
+ if om:
+ soff += 60 * int(om)
+ if os:
+ soff += int(os)
+ tzoff = timedelta(0, soff if sgn == b"+" else -soff)
+
+        # The return value is a datetime with the timezone of the connection
+        # (in order to be consistent with the binary loader, which can only
+        # return a datetime in that timezone). So create a temporary datetime
+        # object in UTC, shift it by the offset parsed from the timestamp, and
+        # then move it to the connection timezone.
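+        # A rough worked example (values are illustrative only): parsing
+        # b"2000-01-01 12:00:00+01" gives tzoff = 1 hour, so the temporary
+        # datetime 2000-01-01 12:00 UTC is shifted back to 11:00 UTC and then
+        # re-expressed in the connection timezone by astimezone().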
+ dt = None
+ ex: Exception
+ try:
+ dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
+ return (dt - tzoff).astimezone(self._timezone)
+ except OverflowError as e:
+ # If we have created the temporary 'dt' it means that we have a
+ # datetime close to max, the shift pushed it past max, overflowing.
+ # In this case return the datetime in a fixed offset timezone.
+ if dt is not None:
+ return dt.replace(tzinfo=timezone(tzoff))
+ else:
+ ex = e
+ except ValueError as e:
+ ex = e
+
+ raise _get_timestamp_load_error(self.connection, data, ex) from None
+
+ def _load_notimpl(self, data: Buffer) -> datetime:
+ s = bytes(data).decode("utf8", "replace")
+ ds = _get_datestyle(self.connection).decode("ascii")
+ raise NotImplementedError(
+ f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
+ )
+
+
+class TimestamptzBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
+
+ def load(self, data: Buffer) -> datetime:
+ micros = unpack_int8(data)[0]
+ try:
+ ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
+ return ts.astimezone(self._timezone)
+ except OverflowError:
+ # If we were asked about a timestamp which would overflow in UTC,
+ # but not in the desired timezone (e.g. datetime.max at Chicago
+ # timezone) we can still save the day by shifting the value by the
+ # timezone offset and then replacing the timezone.
+ if self._timezone:
+ utcoff = self._timezone.utcoffset(
+ datetime.min if micros < 0 else datetime.max
+ )
+ if utcoff:
+ usoff = 1_000_000 * int(utcoff.total_seconds())
+ try:
+ ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
+ except OverflowError:
+ pass # will raise downstream
+ else:
+ return ts.replace(tzinfo=self._timezone)
+
+ if micros <= 0:
+ raise DataError("timestamp too small (before year 1)") from None
+ else:
+ raise DataError("timestamp too large (after year 10K)") from None
+
+
+class IntervalLoader(Loader):
+
+ _re_interval = re.compile(
+ rb"""
+ (?: ([-+]?\d+) \s+ years? \s* )? # Years
+ (?: ([-+]?\d+) \s+ mons? \s* )? # Months
+ (?: ([-+]?\d+) \s+ days? \s* )? # Days
+ (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time
+ )?
+ """,
+ re.VERBOSE,
+ )
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if self.connection:
+ ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
+ if ints != b"postgres":
+ setattr(self, "load", self._load_notimpl)
+
+ def load(self, data: Buffer) -> timedelta:
+ m = self._re_interval.match(data)
+ if not m:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}")
+
+ ye, mo, da, sgn, ho, mi, se = m.groups()
+ days = 0
+ seconds = 0.0
+
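+        # Note: years and months are approximated, because timedelta has
+        # neither: e.g. b"1 year 2 mons" becomes 365 + 2 * 30 = 425 days.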
+ if ye:
+ days += 365 * int(ye)
+ if mo:
+ days += 30 * int(mo)
+ if da:
+ days += int(da)
+
+ if ho:
+ seconds = 3600 * int(ho) + 60 * int(mi) + float(se)
+ if sgn == b"-":
+ seconds = -seconds
+
+ try:
+ return timedelta(days=days, seconds=seconds)
+ except OverflowError as e:
+ s = bytes(data).decode("utf8", "replace")
+ raise DataError(f"can't parse interval {s!r}: {e}") from None
+
+ def _load_notimpl(self, data: Buffer) -> timedelta:
+ s = bytes(data).decode("utf8", "replace")
+ ints = (
+ self.connection
+ and self.connection.pgconn.parameter_status(b"IntervalStyle")
+ or b"unknown"
+ ).decode("utf8", "replace")
+ raise NotImplementedError(
+ f"can't parse interval with IntervalStyle {ints}: {s!r}"
+ )
+
+
+class IntervalBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> timedelta:
+ micros, days, months = _unpack_interval(data)
+ if months > 0:
+ years, months = divmod(months, 12)
+ days = days + 30 * months + 365 * years
+ elif months < 0:
+ years, months = divmod(-months, 12)
+ days = days - 30 * months - 365 * years
+
+ try:
+ return timedelta(days=days, microseconds=micros)
+ except OverflowError as e:
+ raise DataError(f"can't parse interval: {e}") from None
+
+
+def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
+ if conn:
+ ds = conn.pgconn.parameter_status(b"DateStyle")
+ if ds:
+ return ds
+
+ return b"ISO, DMY"
+
+
+def _get_timestamp_load_error(
+ conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
+) -> Exception:
+ s = bytes(data).decode("utf8", "replace")
+
+ def is_overflow(s: str) -> bool:
+ if not s:
+ return False
+
+ ds = _get_datestyle(conn)
+ if not ds.startswith(b"P"): # Postgres
+ return len(s.split()[0]) > 10 # date is first token
+ else:
+ return len(s.split()[-1]) > 4 # year is last token
+
+ if s == "-infinity" or s.endswith("BC"):
+ return DataError("timestamp too small (before year 1): {s!r}")
+ elif s == "infinity" or is_overflow(s):
+ return DataError(f"timestamp too large (after year 10K): {s!r}")
+ else:
+ return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
+
+
+_month_abbr = {
+ n: i
+ for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
+}
+
+# Pad to get microseconds from a fraction of seconds
+_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1]
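+# For example, a textual fraction b"25" has length 2, so int(b"25") * _uspad[2]
+# gives 250_000 microseconds; a 6-digit fraction is already in microseconds.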
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("datetime.date", DateDumper)
+ adapters.register_dumper("datetime.date", DateBinaryDumper)
+
+ # first register dumpers for 'timetz' oid, then the proper ones on time type.
+ adapters.register_dumper("datetime.time", TimeTzDumper)
+ adapters.register_dumper("datetime.time", TimeTzBinaryDumper)
+ adapters.register_dumper("datetime.time", TimeDumper)
+ adapters.register_dumper("datetime.time", TimeBinaryDumper)
+
+ # first register dumpers for 'timestamp' oid, then the proper ones
+ # on the datetime type.
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeDumper)
+ adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper)
+
+ adapters.register_dumper("datetime.timedelta", TimedeltaDumper)
+ adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper)
+
+ adapters.register_loader("date", DateLoader)
+ adapters.register_loader("date", DateBinaryLoader)
+ adapters.register_loader("time", TimeLoader)
+ adapters.register_loader("time", TimeBinaryLoader)
+ adapters.register_loader("timetz", TimetzLoader)
+ adapters.register_loader("timetz", TimetzBinaryLoader)
+ adapters.register_loader("timestamp", TimestampLoader)
+ adapters.register_loader("timestamp", TimestampBinaryLoader)
+ adapters.register_loader("timestamptz", TimestamptzLoader)
+ adapters.register_loader("timestamptz", TimestamptzBinaryLoader)
+ adapters.register_loader("interval", IntervalLoader)
+ adapters.register_loader("interval", IntervalBinaryLoader)
diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py
new file mode 100644
index 0000000..d3c7387
--- /dev/null
+++ b/psycopg/psycopg/types/enum.py
@@ -0,0 +1,177 @@
+"""
+Adapters for the enum type.
+"""
+from enum import Enum
+from typing import Any, Dict, Generic, Optional, Mapping, Sequence
+from typing import Tuple, Type, TypeVar, Union, cast
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from .._encodings import conn_encoding
+from .._typeinfo import EnumInfo as EnumInfo # exported here
+
+E = TypeVar("E", bound=Enum)
+
+EnumDumpMap: TypeAlias = Dict[E, bytes]
+EnumLoadMap: TypeAlias = Dict[bytes, E]
+EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
+
+
+class _BaseEnumLoader(Loader, Generic[E]):
+ """
+ Loader for a specific Enum class
+ """
+
+ enum: Type[E]
+ _load_map: EnumLoadMap[E]
+
+ def load(self, data: Buffer) -> E:
+ if not isinstance(data, bytes):
+ data = bytes(data)
+
+ try:
+ return self._load_map[data]
+ except KeyError:
+ enc = conn_encoding(self.connection)
+ label = data.decode(enc, "replace")
+ raise e.DataError(
+ f"bad member for enum {self.enum.__qualname__}: {label!r}"
+ )
+
+
+class _BaseEnumDumper(Dumper, Generic[E]):
+ """
+ Dumper for a specific Enum class
+ """
+
+ enum: Type[E]
+ _dump_map: EnumDumpMap[E]
+
+ def dump(self, value: E) -> Buffer:
+ return self._dump_map[value]
+
+
+class EnumDumper(Dumper):
+ """
+ Dumper for a generic Enum class
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._encoding = conn_encoding(self.connection)
+
+ def dump(self, value: E) -> Buffer:
+ return value.name.encode(self._encoding)
+
+
+class EnumBinaryDumper(EnumDumper):
+ format = Format.BINARY
+
+
+def register_enum(
+ info: EnumInfo,
+ context: Optional[AdaptContext] = None,
+ enum: Optional[Type[E]] = None,
+ *,
+ mapping: EnumMapping[E] = None,
+) -> None:
+ """Register the adapters to load and dump a enum type.
+
+ :param info: The object with the information about the enum to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+ :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
+ a new enum will be generated and exposed as `EnumInfo.enum`.
+ :param mapping: Override the mapping between `!enum` members and `!info`
+ labels.
+ """
+
+ if not info:
+ raise TypeError("no info passed. Is the requested enum available?")
+
+ if enum is None:
+ enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__))
+
+ info.enum = enum
+ adapters = context.adapters if context else postgres.adapters
+ info.register(context)
+
+ load_map = _make_load_map(info, enum, mapping, context)
+ attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
+
+ name = f"{info.name.title()}Loader"
+ loader = type(name, (_BaseEnumLoader,), attribs)
+ adapters.register_loader(info.oid, loader)
+
+ name = f"{info.name.title()}BinaryLoader"
+ loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
+ adapters.register_loader(info.oid, loader)
+
+ dump_map = _make_dump_map(info, enum, mapping, context)
+ attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
+
+ name = f"{enum.__name__}Dumper"
+ dumper = type(name, (_BaseEnumDumper,), attribs)
+ adapters.register_dumper(info.enum, dumper)
+
+ name = f"{enum.__name__}BinaryDumper"
+ dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY})
+ adapters.register_dumper(info.enum, dumper)
+
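+# A minimal usage sketch (assuming a PostgreSQL enum type "mood" and an open
+# connection "conn"; all names are illustrative only):
+#
+#     from enum import Enum
+#     from psycopg.types.enum import EnumInfo, register_enum
+#
+#     class Mood(Enum):
+#         sad = "sad"
+#         ok = "ok"
+#         happy = "happy"
+#
+#     info = EnumInfo.fetch(conn, "mood")
+#     register_enum(info, conn, Mood)
+#     conn.execute("select 'ok'::mood").fetchone()[0]  # -> Mood.ok
+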
+
+def _make_load_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumLoadMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumLoadMap[E] = {}
+ for label in info.labels:
+ try:
+ member = enum[label]
+ except KeyError:
+ # tolerate a missing enum, assuming it won't be used. If it is we
+ # will get a DataError on fetch.
+ pass
+ else:
+ rv[label.encode(enc)] = member
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[label.encode(enc)] = member
+
+ return rv
+
+
+def _make_dump_map(
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumDumpMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumDumpMap[E] = {}
+ for member in enum:
+ rv[member] = member.name.encode(enc)
+
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[member] = label.encode(enc)
+
+ return rv
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, EnumBinaryDumper)
+ context.adapters.register_dumper(Enum, EnumDumper)
diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py
new file mode 100644
index 0000000..e1ab1d5
--- /dev/null
+++ b/psycopg/psycopg/types/hstore.py
@@ -0,0 +1,131 @@
+"""
+Dict to hstore adaptation
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+from typing import Dict, List, Optional
+from typing_extensions import TypeAlias
+
+from .. import errors as e
+from .. import postgres
+from ..abc import Buffer, AdaptContext
+from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..postgres import TEXT_OID
+from .._typeinfo import TypeInfo
+
+_re_escape = re.compile(r'(["\\])')
+_re_unescape = re.compile(r"\\(.)")
+
+_re_hstore = re.compile(
+ r"""
+ # hstore key:
+ # a string of normal or escaped chars
+ "((?: [^"\\] | \\. )*)"
+ \s*=>\s* # hstore value
+ (?:
+ NULL # the value can be null - not caught
+ # or a quoted string like the key
+ | "((?: [^"\\] | \\. )*)"
+ )
+ (?:\s*,\s*|$) # pairs separated by comma or end of string.
+""",
+ re.VERBOSE,
+)
+
+
+Hstore: TypeAlias = Dict[str, Optional[str]]
+
+
+class BaseHstoreDumper(RecursiveDumper):
+ def dump(self, obj: Hstore) -> Buffer:
+ if not obj:
+ return b""
+
+ tokens: List[str] = []
+
+ def add_token(s: str) -> None:
+ tokens.append('"')
+ tokens.append(_re_escape.sub(r"\\\1", s))
+ tokens.append('"')
+
+ for k, v in obj.items():
+
+ if not isinstance(k, str):
+ raise e.DataError("hstore keys can only be strings")
+ add_token(k)
+
+ tokens.append("=>")
+
+ if v is None:
+ tokens.append("NULL")
+ elif not isinstance(v, str):
+ raise e.DataError("hstore keys can only be strings")
+ else:
+ add_token(v)
+
+ tokens.append(",")
+
+ del tokens[-1]
+ data = "".join(tokens)
+ dumper = self._tx.get_dumper(data, PyFormat.TEXT)
+ return dumper.dump(data)
+
+
+class HstoreLoader(RecursiveLoader):
+ def load(self, data: Buffer) -> Hstore:
+ loader = self._tx.get_loader(TEXT_OID, self.format)
+ s: str = loader.load(data)
+
+ rv: Hstore = {}
+ start = 0
+ for m in _re_hstore.finditer(s):
+ if m is None or m.start() != start:
+ raise e.DataError(f"error parsing hstore pair at char {start}")
+ k = _re_unescape.sub(r"\1", m.group(1))
+ v = m.group(2)
+ if v is not None:
+ v = _re_unescape.sub(r"\1", v)
+
+ rv[k] = v
+ start = m.end()
+
+ if start < len(s):
+ raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
+
+ return rv
+
+
+def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump hstore.
+
+ :param info: The object with the information about the hstore type.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'hstore' extension loaded?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # Generate and register a customized text dumper
+ class HstoreDumper(BaseHstoreDumper):
+ oid = info.oid
+
+ adapters.register_dumper(dict, HstoreDumper)
+
+ # register the text loader on the oid
+ adapters.register_loader(info.oid, HstoreLoader)
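+
+
+# A minimal usage sketch (assuming the hstore extension is installed in the
+# target database and "conn" is an open connection; names are illustrative):
+#
+#     from psycopg.types import TypeInfo
+#     from psycopg.types.hstore import register_hstore
+#
+#     info = TypeInfo.fetch(conn, "hstore")
+#     register_hstore(info, conn)
+#     conn.execute("select 'a=>1, b=>NULL'::hstore").fetchone()[0]
+#     # -> {'a': '1', 'b': None}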
diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py
new file mode 100644
index 0000000..a80e0e4
--- /dev/null
+++ b/psycopg/psycopg/types/json.py
@@ -0,0 +1,232 @@
+"""
+Adapters for JSON types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import json
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
+from .. import abc
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
+from ..errors import DataError
+
+JsonDumpsFunction = Callable[[Any], str]
+JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
+
+
+def set_json_dumps(
+ dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON serialisation function to store JSON objects in the database.
+
+ :param dumps: The dump function to use.
+ :type dumps: `!Callable[[Any], str]`
+ :param context: Where to use the `!dumps` function. If not specified, use it
+ globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default dumping JSON uses the builtin `json.dumps`. You can override
+ it to use a different JSON library or to use customised arguments.
+
+    If the `Json` wrapper specifies a `!dumps` function, it takes precedence
+    over the one set by this function.
+ """
+ if context is None:
+        # If changing the dump function globally, just change the default on the
+ # global class
+ _JsonDumper._dumps = dumps
+ else:
+ adapters = context.adapters
+
+        # If the scope is smaller than global, create subclasses and register
+ # them in the appropriate scope.
+ grid = [
+ (Json, PyFormat.BINARY),
+ (Json, PyFormat.TEXT),
+ (Jsonb, PyFormat.BINARY),
+ (Jsonb, PyFormat.TEXT),
+ ]
+ dumper: Type[_JsonDumper]
+ for wrapper, format in grid:
+ base = _get_current_dumper(adapters, wrapper, format)
+ name = base.__name__
+ if not base.__name__.startswith("Custom"):
+ name = f"Custom{name}"
+ dumper = type(name, (base,), {"_dumps": dumps})
+ adapters.register_dumper(wrapper, dumper)
+
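+# A minimal customisation sketch (the arguments are illustrative only):
+#
+#     import json
+#     from functools import partial
+#     from psycopg.types.json import set_json_dumps
+#
+#     # don't escape non-ASCII characters when sending JSON to the database
+#     set_json_dumps(partial(json.dumps, ensure_ascii=False))
+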
+
+def set_json_loads(
+ loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
+) -> None:
+ """
+ Set the JSON parsing function to fetch JSON objects from the database.
+
+ :param loads: The load function to use.
+ :type loads: `!Callable[[bytes], Any]`
+ :param context: Where to use the `!loads` function. If not specified, use
+ it globally.
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor`
+
+ By default loading JSON uses the builtin `json.loads`. You can override
+ it to use a different JSON library or to use customised arguments.
+ """
+ if context is None:
+ # If changing load function globally, just change the default on the
+ # global class
+ _JsonLoader._loads = loads
+ else:
+        # If the scope is smaller than global, create subclasses and register
+ # them in the appropriate scope.
+ grid = [
+ ("json", JsonLoader),
+ ("json", JsonBinaryLoader),
+ ("jsonb", JsonbLoader),
+ ("jsonb", JsonbBinaryLoader),
+ ]
+ loader: Type[_JsonLoader]
+ for tname, base in grid:
+ loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads})
+ context.adapters.register_loader(tname, loader)
+
+
+class _JsonWrapper:
+ __slots__ = ("obj", "dumps")
+
+ def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
+ self.obj = obj
+ self.dumps = dumps
+
+ def __repr__(self) -> str:
+ sobj = repr(self.obj)
+ if len(sobj) > 40:
+ sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
+ return f"{self.__class__.__name__}({sobj})"
+
+
+class Json(_JsonWrapper):
+ __slots__ = ()
+
+
+class Jsonb(_JsonWrapper):
+ __slots__ = ()
+
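+# A minimal usage sketch (assuming "conn" is an open connection; the table
+# name and the "mydumps" function are illustrative only):
+#
+#     conn.execute("insert into mytable (data) values (%s)", [Jsonb({"a": 1})])
+#     # or, with a per-object serialisation function:
+#     conn.execute(
+#         "insert into mytable (data) values (%s)",
+#         [Jsonb({"a": 1}, dumps=mydumps)],
+#     )
+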
+
+class _JsonDumper(Dumper):
+
+ # The globally used JSON dumps() function. It can be changed globally (by
+ # set_json_dumps) or by a subclass.
+ _dumps: JsonDumpsFunction = json.dumps
+
+ def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
+ super().__init__(cls, context)
+ self.dumps = self.__class__._dumps
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return dumps(obj.obj).encode()
+
+
+class JsonDumper(_JsonDumper):
+
+ oid = postgres.types["json"].oid
+
+
+class JsonBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["json"].oid
+
+
+class JsonbDumper(_JsonDumper):
+
+ oid = postgres.types["jsonb"].oid
+
+
+class JsonbBinaryDumper(_JsonDumper):
+
+ format = Format.BINARY
+ oid = postgres.types["jsonb"].oid
+
+ def dump(self, obj: _JsonWrapper) -> bytes:
+ dumps = obj.dumps or self.dumps
+ return b"\x01" + dumps(obj.obj).encode()
+
+
+class _JsonLoader(Loader):
+
+ # The globally used JSON loads() function. It can be changed globally (by
+ # set_json_loads) or by a subclass.
+ _loads: JsonLoadsFunction = json.loads
+
+ def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
+ super().__init__(oid, context)
+ self.loads = self.__class__._loads
+
+ def load(self, data: Buffer) -> Any:
+ # json.loads() cannot work on memoryview.
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+class JsonLoader(_JsonLoader):
+ pass
+
+
+class JsonbLoader(_JsonLoader):
+ pass
+
+
+class JsonBinaryLoader(_JsonLoader):
+ format = Format.BINARY
+
+
+class JsonbBinaryLoader(_JsonLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Any:
+ if data and data[0] != 1:
+ raise DataError("unknown jsonb binary format: {data[0]}")
+ data = data[1:]
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return self.loads(data)
+
+
+def _get_current_dumper(
+ adapters: AdaptersMap, cls: type, format: PyFormat
+) -> Type[abc.Dumper]:
+ try:
+ return adapters.get_dumper(cls, format)
+ except e.ProgrammingError:
+ return _default_dumpers[cls, format]
+
+
+_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
+ (Json, PyFormat.BINARY): JsonBinaryDumper,
+ (Json, PyFormat.TEXT): JsonDumper,
+ (Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
+    (Jsonb, PyFormat.TEXT): JsonbDumper,
+}
+
+
+def register_default_adapters(context: abc.AdaptContext) -> None:
+ adapters = context.adapters
+
+    # Currently the json binary format is no different from the text one,
+    # except for an extra memory copy we might be able to avoid.
+ adapters.register_dumper(Json, JsonBinaryDumper)
+ adapters.register_dumper(Json, JsonDumper)
+ adapters.register_dumper(Jsonb, JsonbBinaryDumper)
+ adapters.register_dumper(Jsonb, JsonbDumper)
+ adapters.register_loader("json", JsonLoader)
+ adapters.register_loader("jsonb", JsonbLoader)
+ adapters.register_loader("json", JsonBinaryLoader)
+ adapters.register_loader("jsonb", JsonbBinaryLoader)
diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py
new file mode 100644
index 0000000..3eaa7f1
--- /dev/null
+++ b/psycopg/psycopg/types/multirange.py
@@ -0,0 +1,514 @@
+"""
+Support for multirange types adaptation.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from decimal import Decimal
+from typing import Any, Generic, List, Iterable
+from typing import MutableSequence, Optional, Type, Union, overload
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
+
+from .range import Range, T, load_range_text, load_range_binary
+from .range import dump_range_text, dump_range_binary, fail_dump
+
+
+class Multirange(MutableSequence[Range[T]]):
+ """Python representation for a PostgreSQL multirange type.
+
+ :param items: Sequence of ranges to initialise the object.
+ """
+
+ def __init__(self, items: Iterable[Range[T]] = ()):
+ self._ranges: List[Range[T]] = list(map(self._check_type, items))
+
+ def _check_type(self, item: Any) -> Range[Any]:
+ if not isinstance(item, Range):
+ raise TypeError(
+ f"Multirange is a sequence of Range, got {type(item).__name__}"
+ )
+ return item
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._ranges!r})"
+
+ def __str__(self) -> str:
+ return f"{{{', '.join(map(str, self._ranges))}}}"
+
+ @overload
+ def __getitem__(self, index: int) -> Range[T]:
+ ...
+
+ @overload
+ def __getitem__(self, index: slice) -> "Multirange[T]":
+ ...
+
+ def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
+ if isinstance(index, int):
+ return self._ranges[index]
+ else:
+ return Multirange(self._ranges[index])
+
+ def __len__(self) -> int:
+ return len(self._ranges)
+
+ @overload
+ def __setitem__(self, index: int, value: Range[T]) -> None:
+ ...
+
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
+ ...
+
+ def __setitem__(
+ self,
+ index: Union[int, slice],
+ value: Union[Range[T], Iterable[Range[T]]],
+ ) -> None:
+ if isinstance(index, int):
+ self._check_type(value)
+ self._ranges[index] = self._check_type(value)
+ elif not isinstance(value, Iterable):
+ raise TypeError("can only assign an iterable")
+ else:
+ value = map(self._check_type, value)
+ self._ranges[index] = value
+
+ def __delitem__(self, index: Union[int, slice]) -> None:
+ del self._ranges[index]
+
+ def insert(self, index: int, value: Range[T]) -> None:
+ self._ranges.insert(index, self._check_type(value))
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return False
+ return self._ranges == other._ranges
+
+ # Order is arbitrary but consistent
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges < other._ranges
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if not isinstance(other, Multirange):
+ return NotImplemented
+ return self._ranges > other._ranges
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
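+# A minimal usage sketch (the values are illustrative only):
+#
+#     from psycopg.types.range import Range
+#     from psycopg.types.multirange import Multirange
+#
+#     mr = Multirange([Range(2, 4), Range(10, None)])
+#     mr[0]    # -> Range(2, 4)
+#     len(mr)  # -> 2
+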
+
+# Subclasses to specify a specific subtype. Usually not needed
+
+
+class Int4Multirange(Multirange[int]):
+ pass
+
+
+class Int8Multirange(Multirange[int]):
+ pass
+
+
+class NumericMultirange(Multirange[Decimal]):
+ pass
+
+
+class DateMultirange(Multirange[date]):
+ pass
+
+
+class TimestampMultirange(Multirange[datetime]):
+ pass
+
+
+class TimestamptzMultirange(Multirange[datetime]):
+ pass
+
+
+class BaseMultirangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Multirange:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return MultirangeDumper(self.cls)
+
+ dumper: BaseMultirangeDumper
+ if type(item) is int:
+            # postgres won't cast int4multirange -> int8multirange so we must
+            # use text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = MultirangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_multirange_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_multirange_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Multirange[Any]) -> Any:
+ """
+ Return a member representative of the multirange
+ """
+ for r in obj:
+ if r.lower is not None:
+ return r.lower
+ if r.upper is not None:
+ return r.upper
+ return None
+
+ def _get_multirange_oid(self, sub_oid: int) -> int:
+ """
+        Return the oid of the multirange from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class MultirangeDumper(BaseMultirangeDumper):
+ """
+ Dumper for multirange types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ if not obj:
+ return b"{}"
+
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [b"{"]
+ for r in obj:
+ out.append(dump_range_text(r, dump))
+ out.append(b",")
+ out[-1] = b"}"
+ return b"".join(out)
+
+
+class MultirangeBinaryDumper(BaseMultirangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Multirange[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ out: List[Buffer] = [pack_len(len(obj))]
+ for r in obj:
+ data = dump_range_binary(r, dump)
+ out.append(pack_len(len(data)))
+ out.append(data)
+ return b"".join(out)
+
+
+class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class MultirangeLoader(BaseMultirangeLoader[T]):
+ def load(self, data: Buffer) -> Multirange[T]:
+ if not data or data[0] != _START_INT:
+ raise e.DataError(
+ "malformed multirange starting with"
+ f" {bytes(data[:1]).decode('utf8', 'replace')}"
+ )
+
+ out = Multirange[T]()
+ if data == b"{}":
+ return out
+
+ pos = 1
+ data = data[pos:]
+ try:
+ while True:
+ r, pos = load_range_text(data, self._load)
+ out.append(r)
+
+ sep = data[pos] # can raise IndexError
+ if sep == _SEP_INT:
+ data = data[pos + 1 :]
+ continue
+ elif sep == _END_INT:
+ if len(data) == pos + 1:
+ return out
+ else:
+ raise e.DataError(
+ "malformed multirange: data after closing brace"
+ )
+ else:
+ raise e.DataError(
+ f"malformed multirange: found unexpected {chr(sep)}"
+ )
+
+ except IndexError:
+ raise e.DataError("malformed multirange: separator missing")
+
+
+_SEP_INT = ord(",")
+_START_INT = ord("{")
+_END_INT = ord("}")
+
+
+class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Multirange[T]:
+ nelems = unpack_len(data, 0)[0]
+ pos = 4
+ out = Multirange[T]()
+ for i in range(nelems):
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ out.append(load_range_binary(data[pos : pos + length], self._load))
+ pos += length
+
+ if pos != len(data):
+ raise e.DataError("unexpected trailing data in multirange")
+
+ return out
+
+
+def register_multirange(
+ info: MultirangeInfo, context: Optional[AdaptContext] = None
+) -> None:
+ """Register the adapters to load and dump a multirange type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error warning instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested multirange available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[MultirangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (MultirangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[MultirangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (MultirangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
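+# A minimal usage sketch (assuming a multirange type "mytypemultirange" exists
+# in the database and "conn" is an open connection; names are illustrative):
+#
+#     from psycopg.types.multirange import MultirangeInfo, register_multirange
+#
+#     info = MultirangeInfo.fetch(conn, "mytypemultirange")
+#     register_multirange(info, conn)
+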
+
+# Text dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeDumper(MultirangeDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeDumper(MultirangeDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Binary dumpers for builtin multirange types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int4multirange"].oid
+
+
+class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["int8multirange"].oid
+
+
+class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["nummultirange"].oid
+
+
+class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["datemultirange"].oid
+
+
+class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tsmultirange"].oid
+
+
+class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
+ oid = postgres.types["tstzmultirange"].oid
+
+
+# Text loaders for builtin multirange types
+
+
+class Int4MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeLoader(MultirangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeLoader(MultirangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeLoader(MultirangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin multirange types
+
+
+class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Multirange, MultirangeBinaryDumper)
+ adapters.register_dumper(Multirange, MultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
+ adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
+ adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
+ adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
+ adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
+ adapters.register_loader("int4multirange", Int4MultirangeLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeLoader)
+ adapters.register_loader("datemultirange", DateMultirangeLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader)
+ adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader)
+ adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader)
+ adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
+ adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
+ adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py
new file mode 100644
index 0000000..2f2c05b
--- /dev/null
+++ b/psycopg/psycopg/types/net.py
@@ -0,0 +1,206 @@
+"""
+Adapters for network types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, Type, Union, TYPE_CHECKING
+from typing_extensions import TypeAlias
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import ipaddress
+
+Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"]
+Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"]
+Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"]
+
+# These objects will be imported lazily
+ip_address: Callable[[str], Address] = None # type: ignore[assignment]
+ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
+ip_network: Callable[[str], Network] = None # type: ignore[assignment]
+IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
+IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
+IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
+IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
+IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
+IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
+
+PGSQL_AF_INET = 2
+PGSQL_AF_INET6 = 3
+IPV4_PREFIXLEN = 32
+IPV6_PREFIXLEN = 128
+
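+# The inet/cidr binary format used by the binary dumpers and loaders below is
+# a 4-byte header (family, prefix length, is_cidr flag, address length)
+# followed by the packed address. For example (illustrative value),
+# IPv4Address("192.0.2.1") is dumped as b"\x02\x20\x00\x04" +
+# b"\xc0\x00\x02\x01" (family 2, prefix 32, not cidr, 4 address bytes).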
+
+class _LazyIpaddress:
+ def _ensure_module(self) -> None:
+ global ip_address, ip_interface, ip_network
+ global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface
+ global IPv4Network, IPv6Network
+
+ if ip_address is None:
+ from ipaddress import ip_address, ip_interface, ip_network
+ from ipaddress import IPv4Address, IPv6Address
+ from ipaddress import IPv4Interface, IPv6Interface
+ from ipaddress import IPv4Network, IPv6Network
+
+
+class InterfaceDumper(Dumper):
+
+ oid = postgres.types["inet"].oid
+
+ def dump(self, obj: Interface) -> bytes:
+ return str(obj).encode()
+
+
+class NetworkDumper(Dumper):
+
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ return str(obj).encode()
+
+
+class _AIBinaryDumper(Dumper):
+ format = Format.BINARY
+ oid = postgres.types["inet"].oid
+
+
+class AddressBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Address) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.max_prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InterfaceBinaryDumper(_AIBinaryDumper):
+ def dump(self, obj: Interface) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.network.prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
+ """Either an address or an interface to inet
+
+ Used when looking up by oid.
+ """
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._ensure_module()
+
+ def dump(self, obj: Union[Address, Interface]) -> bytes:
+ packed = obj.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ if isinstance(obj, (IPv4Interface, IPv6Interface)):
+ prefixlen = obj.network.prefixlen
+ else:
+ prefixlen = obj.max_prefixlen
+
+ head = bytes((family, prefixlen, 0, len(packed)))
+ return head + packed
+
+
+class NetworkBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["cidr"].oid
+
+ def dump(self, obj: Network) -> bytes:
+ packed = obj.network_address.packed
+ family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
+ head = bytes((family, obj.prefixlen, 1, len(packed)))
+ return head + packed
+
+
+class _LazyIpaddressLoader(Loader, _LazyIpaddress):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._ensure_module()
+
+
+class InetLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ if b"/" in data:
+ return ip_interface(data.decode())
+ else:
+ return ip_address(data.decode())
+
+
+class InetBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ if prefix == IPV4_PREFIXLEN:
+ return IPv4Address(packed)
+ else:
+ return IPv4Interface((packed, prefix))
+ else:
+ if prefix == IPV6_PREFIXLEN:
+ return IPv6Address(packed)
+ else:
+ return IPv6Interface((packed, prefix))
+
+
+class CidrLoader(_LazyIpaddressLoader):
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ return ip_network(data.decode())
+
+
+class CidrBinaryLoader(_LazyIpaddressLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
+ prefix = data[1]
+ packed = data[4:]
+ if data[0] == PGSQL_AF_INET:
+ return IPv4Network((packed, prefix))
+ else:
+ return IPv6Network((packed, prefix))
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper)
+ adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper)
+ adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper)
+ adapters.register_dumper(None, InetBinaryDumper)
+ adapters.register_loader("inet", InetLoader)
+ adapters.register_loader("inet", InetBinaryLoader)
+ adapters.register_loader("cidr", CidrLoader)
+ adapters.register_loader("cidr", CidrBinaryLoader)
diff --git a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py
new file mode 100644
index 0000000..2ab857c
--- /dev/null
+++ b/psycopg/psycopg/types/none.py
@@ -0,0 +1,25 @@
+"""
+Adapters for None.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from ..abc import AdaptContext, NoneType
+from ..adapt import Dumper
+
+
+class NoneDumper(Dumper):
+ """
+ Not a complete dumper as it doesn't implement dump(), but it implements
+ quote(), so it can be used in sql composition.
+ """
+
+ def dump(self, obj: None) -> bytes:
+ raise NotImplementedError("NULL is passed to Postgres in other ways")
+
+ def quote(self, obj: None) -> bytes:
+ return b"NULL"
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(NoneType, NoneDumper)
diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py
new file mode 100644
index 0000000..1bd9329
--- /dev/null
+++ b/psycopg/psycopg/types/numeric.py
@@ -0,0 +1,515 @@
+"""
+Adapters for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import struct
+from math import log
+from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
+from decimal import Decimal, DefaultContext, Context
+
+from .. import postgres
+from .. import errors as e
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader, PyFormat
+from .._struct import pack_int2, pack_uint2, unpack_int2
+from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
+from .._struct import pack_int8, unpack_int8
+from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
+
+# Exposed here
+from .._wrappers import (
+ Int2 as Int2,
+ Int4 as Int4,
+ Int8 as Int8,
+ IntNumeric as IntNumeric,
+ Oid as Oid,
+ Float4 as Float4,
+ Float8 as Float8,
+)
+
+
+class _IntDumper(Dumper):
+ def dump(self, obj: Any) -> Buffer:
+ t = type(obj)
+ if t is not int:
+ # Convert to int in order to dump IntEnum correctly
+ if issubclass(t, int):
+ obj = int(obj)
+ else:
+ raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> Buffer:
+ value = self.dump(obj)
+ return value if obj >= 0 else b" " + value
+
+
+class _SpecialValuesDumper(Dumper):
+
+ _special: Dict[bytes, bytes] = {}
+
+ def dump(self, obj: Any) -> bytes:
+ return str(obj).encode()
+
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+
+ if value in self._special:
+ return self._special[value]
+
+ return value if obj >= 0 else b" " + value
+
+
+class FloatDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["float8"].oid
+
+ _special = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+ }
+
+
+class Float4Dumper(FloatDumper):
+ oid = postgres.types["float4"].oid
+
+
+class FloatBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["float8"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float8(obj)
+
+
+class Float4BinaryDumper(FloatBinaryDumper):
+
+ oid = postgres.types["float4"].oid
+
+ def dump(self, obj: float) -> bytes:
+ return pack_float4(obj)
+
+
+class DecimalDumper(_SpecialValuesDumper):
+
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> bytes:
+ if obj.is_nan():
+ # cover NaN and sNaN
+ return b"NaN"
+ else:
+ return str(obj).encode()
+
+ _special = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+ }
+
+
+class Int2Dumper(_IntDumper):
+ oid = postgres.types["int2"].oid
+
+
+class Int4Dumper(_IntDumper):
+ oid = postgres.types["int4"].oid
+
+
+class Int8Dumper(_IntDumper):
+ oid = postgres.types["int8"].oid
+
+
+class IntNumericDumper(_IntDumper):
+ oid = postgres.types["numeric"].oid
+
+
+class OidDumper(_IntDumper):
+ oid = postgres.types["oid"].oid
+
+
+class IntDumper(Dumper):
+ def dump(self, obj: Any) -> bytes:
+ raise TypeError(
+ f"{type(self).__name__} is a dispatcher to other dumpers:"
+ " dump() is not supposed to be called"
+ )
+
+ def get_key(self, obj: int, format: PyFormat) -> type:
+ return self.upgrade(obj, format).cls
+
+ _int2_dumper = Int2Dumper(Int2)
+ _int4_dumper = Int4Dumper(Int4)
+ _int8_dumper = Int8Dumper(Int8)
+ _int_numeric_dumper = IntNumericDumper(IntNumeric)
+
+ def upgrade(self, obj: int, format: PyFormat) -> Dumper:
+ if -(2**31) <= obj < 2**31:
+ if -(2**15) <= obj < 2**15:
+ return self._int2_dumper
+ else:
+ return self._int4_dumper
+ else:
+ if -(2**63) <= obj < 2**63:
+ return self._int8_dumper
+ else:
+ return self._int_numeric_dumper
+
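+# For example, the dispatcher above dumps 1_000 with the int2 dumper, 100_000
+# with the int4 dumper, 10**10 with the int8 dumper and 10**20 with the
+# numeric dumper.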
+
+class Int2BinaryDumper(Int2Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int2(obj)
+
+
+class Int4BinaryDumper(Int4Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int4(obj)
+
+
+class Int8BinaryDumper(Int8Dumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_int8(obj)
+
+
+# Ratio between number of bits required to store a number and number of pg
+# decimal digits required.
+BIT_PER_PGDIGIT = log(2) / log(10_000)
+
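+# For example, a 64-bit integer needs at most int(64 * BIT_PER_PGDIGIT) + 1 =
+# 5 base-10_000 digits: 2**64 is about 1.8e19, i.e. five groups of 4 decimal
+# digits.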
+
+class IntNumericBinaryDumper(IntNumericDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> Buffer:
+ return dump_int_to_numeric_binary(obj)
+
+
+class OidBinaryDumper(OidDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ return pack_uint4(obj)
+
+
+class IntBinaryDumper(IntDumper):
+
+ format = Format.BINARY
+
+ _int2_dumper = Int2BinaryDumper(Int2)
+ _int4_dumper = Int4BinaryDumper(Int4)
+ _int8_dumper = Int8BinaryDumper(Int8)
+ _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
+
+
+class IntLoader(Loader):
+ def load(self, data: Buffer) -> int:
+ # it supports bytes directly
+ return int(data)
+
+
+class Int2BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int2(data)[0]
+
+
+class Int4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int4(data)[0]
+
+
+class Int8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_int8(data)[0]
+
+
+class OidBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> int:
+ return unpack_uint4(data)[0]
+
+
+class FloatLoader(Loader):
+ def load(self, data: Buffer) -> float:
+ # it supports bytes directly
+ return float(data)
+
+
+class Float4BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float4(data)[0]
+
+
+class Float8BinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> float:
+ return unpack_float8(data)[0]
+
+
+class NumericLoader(Loader):
+ def load(self, data: Buffer) -> Decimal:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return Decimal(data.decode())
+
+
+DEC_DIGITS = 4 # decimal digits per Postgres "digit"
+NUMERIC_POS = 0x0000
+NUMERIC_NEG = 0x4000
+NUMERIC_NAN = 0xC000
+NUMERIC_PINF = 0xD000
+NUMERIC_NINF = 0xF000
+
+_decimal_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
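+# For example (an illustrative value), Decimal("123.45") travels on the wire
+# as ndigits=2, weight=0, dscale=2 and the base-10_000 digits (123, 4500),
+# i.e. 123 * 10_000**0 + 4500 * 10_000**-1.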
+
+class _ContextMap(DefaultDict[int, Context]):
+ """
+ Cache for decimal contexts to use when the precision requires it.
+
+ Note: if the default context is used (prec=28) you can get an invalid
+ operation or a rounding to 0:
+
+ - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
+ - Decimal(1000).shift(25) = Decimal('0')
+ - Decimal(1000).shift(30) raises InvalidOperation
+ """
+
+ def __missing__(self, key: int) -> Context:
+ val = Context(prec=key)
+ self[key] = val
+ return val
+
+
+_contexts = _ContextMap()
+for i in range(DefaultContext.prec):
+ _contexts[i] = DefaultContext
+
+_unpack_numeric_head = cast(
+ Callable[[Buffer], Tuple[int, int, int, int]],
+ struct.Struct("!HhHH").unpack_from,
+)
+_pack_numeric_head = cast(
+ Callable[[int, int, int, int], bytes],
+ struct.Struct("!HhHH").pack,
+)
+
+
+class NumericBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Decimal:
+ ndigits, weight, sign, dscale = _unpack_numeric_head(data)
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ val = 0
+ for i in range(8, len(data), 2):
+ val = val * 10_000 + data[i] * 0x100 + data[i + 1]
+
+ shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+ ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, ctx)
+ .shift(shift, ctx)
+ )
+ else:
+ try:
+ return _decimal_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
+
+
+NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
+NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
+NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
+
+
+class DecimalBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Decimal) -> Buffer:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+class NumericDumper(DecimalDumper):
+ def dump(self, obj: Union[Decimal, int]) -> bytes:
+ if isinstance(obj, int):
+ return str(obj).encode()
+ else:
+ return super().dump(obj)
+
+
+class NumericBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["numeric"].oid
+
+ def dump(self, obj: Union[Decimal, int]) -> Buffer:
+ if isinstance(obj, int):
+ return dump_int_to_numeric_binary(obj)
+ else:
+ return dump_decimal_to_numeric_binary(obj)
+
+
+def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
+ sign, digits, exp = obj.as_tuple()
+ if exp == "n" or exp == "N": # type: ignore[comparison-overlap]
+ return NUMERIC_NAN_BIN
+ elif exp == "F": # type: ignore[comparison-overlap]
+ return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN
+
+ # Weights of py digits into a pg digit according to their positions.
+ # Starting with an index wi != 0 is equivalent to prepending 0's to
+ # the digits tuple, but without really changing it.
+ weights = (1000, 100, 10, 1)
+ wi = 0
+
+ ndigits = nzdigits = len(digits)
+
+ # Find the last nonzero digit
+ while nzdigits > 0 and digits[nzdigits - 1] == 0:
+ nzdigits -= 1
+
+ if exp <= 0:
+ dscale = -exp
+ else:
+ dscale = 0
+ # align the py digits to the pg digits if there's some py exponent
+ ndigits += exp % DEC_DIGITS
+
+ if not nzdigits:
+ return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)
+
+ # Equivalent of 0-padding left to align the py digits to the pg digits
+ # but without changing the digits tuple.
+ mod = (ndigits - dscale) % DEC_DIGITS
+ if mod:
+ wi = DEC_DIGITS - mod
+ ndigits += wi
+
+ tmp = nzdigits + wi
+ out = bytearray(
+ _pack_numeric_head(
+ tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits
+ (ndigits + exp) // DEC_DIGITS - 1, # weight
+ NUMERIC_NEG if sign else NUMERIC_POS, # sign
+ dscale,
+ )
+ )
+
+ pgdigit = 0
+ for i in range(nzdigits):
+ pgdigit += weights[wi] * digits[i]
+ wi += 1
+ if wi >= DEC_DIGITS:
+ out += pack_uint2(pgdigit)
+ pgdigit = wi = 0
+
+ if pgdigit:
+ out += pack_uint2(pgdigit)
+
+ return out
+
+
+def dump_int_to_numeric_binary(obj: int) -> bytearray:
+ ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
+ out = bytearray(b"\x00\x00" * (ndigits + 4))
+ if obj < 0:
+ sign = NUMERIC_NEG
+ obj = -obj
+ else:
+ sign = NUMERIC_POS
+
+ out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)
+ i = 8 + (ndigits - 1) * 2
+ while obj:
+ rem = obj % 10_000
+ obj //= 10_000
+ out[i : i + 2] = pack_uint2(rem)
+ i -= 2
+
+ return out
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(int, IntDumper)
+ adapters.register_dumper(int, IntBinaryDumper)
+ adapters.register_dumper(float, FloatDumper)
+ adapters.register_dumper(float, FloatBinaryDumper)
+ adapters.register_dumper(Int2, Int2Dumper)
+ adapters.register_dumper(Int4, Int4Dumper)
+ adapters.register_dumper(Int8, Int8Dumper)
+ adapters.register_dumper(IntNumeric, IntNumericDumper)
+ adapters.register_dumper(Oid, OidDumper)
+
+ # The binary dumper is currently some 30% slower, so default to text
+ # (see tests/scripts/testdec.py for a rough benchmark).
+ # Also, it must be registered after IntNumericDumper.
+ adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
+ adapters.register_dumper("decimal.Decimal", DecimalDumper)
+
+ # Used only by oid, can take both int and Decimal as input
+ adapters.register_dumper(None, NumericBinaryDumper)
+ adapters.register_dumper(None, NumericDumper)
+
+ adapters.register_dumper(Float4, Float4Dumper)
+ adapters.register_dumper(Float8, FloatDumper)
+ adapters.register_dumper(Int2, Int2BinaryDumper)
+ adapters.register_dumper(Int4, Int4BinaryDumper)
+ adapters.register_dumper(Int8, Int8BinaryDumper)
+ adapters.register_dumper(Oid, OidBinaryDumper)
+ adapters.register_dumper(Float4, Float4BinaryDumper)
+ adapters.register_dumper(Float8, FloatBinaryDumper)
+ adapters.register_loader("int2", IntLoader)
+ adapters.register_loader("int4", IntLoader)
+ adapters.register_loader("int8", IntLoader)
+ adapters.register_loader("oid", IntLoader)
+ adapters.register_loader("int2", Int2BinaryLoader)
+ adapters.register_loader("int4", Int4BinaryLoader)
+ adapters.register_loader("int8", Int8BinaryLoader)
+ adapters.register_loader("oid", OidBinaryLoader)
+ adapters.register_loader("float4", FloatLoader)
+ adapters.register_loader("float8", FloatLoader)
+ adapters.register_loader("float4", Float4BinaryLoader)
+ adapters.register_loader("float8", Float8BinaryLoader)
+ adapters.register_loader("numeric", NumericLoader)
+ adapters.register_loader("numeric", NumericBinaryLoader)
diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py
new file mode 100644
index 0000000..c418480
--- /dev/null
+++ b/psycopg/psycopg/types/range.py
@@ -0,0 +1,700 @@
+"""
+Support for range types adaptation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
+from typing import cast
+from decimal import Decimal
+from datetime import date, datetime
+
+from .. import errors as e
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey
+from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
+from .._struct import pack_len, unpack_len
+from ..postgres import INVALID_OID, TEXT_OID
+from .._typeinfo import RangeInfo as RangeInfo # exported here
+
+RANGE_EMPTY = 0x01 # range is empty
+RANGE_LB_INC = 0x02 # lower bound is inclusive
+RANGE_UB_INC = 0x04 # upper bound is inclusive
+RANGE_LB_INF = 0x08 # lower bound is -infinity
+RANGE_UB_INF = 0x10 # upper bound is +infinity
+
+_EMPTY_HEAD = bytes([RANGE_EMPTY])
+
+T = TypeVar("T")
+
+
+class Range(Generic[T]):
+ """Python representation for a PostgreSQL range type.
+
+ :param lower: lower bound for the range. `!None` means unbound
+ :param upper: upper bound for the range. `!None` means unbound
+ :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``,
+ representing whether the lower or upper bounds are included
+ :param empty: if `!True`, the range is empty
+
+ """
+
+ __slots__ = ("_lower", "_upper", "_bounds")
+
+ def __init__(
+ self,
+ lower: Optional[T] = None,
+ upper: Optional[T] = None,
+ bounds: str = "[)",
+ empty: bool = False,
+ ):
+ if not empty:
+ if bounds not in ("[)", "(]", "()", "[]"):
+ raise ValueError("bound flags not valid: %r" % bounds)
+
+ self._lower = lower
+ self._upper = upper
+
+ # Make bounds consistent with infs
+ if lower is None and bounds[0] == "[":
+ bounds = "(" + bounds[1]
+ if upper is None and bounds[1] == "]":
+ bounds = bounds[0] + ")"
+
+ self._bounds = bounds
+ else:
+ self._lower = self._upper = None
+ self._bounds = ""
+
+ def __repr__(self) -> str:
+ if self._bounds:
+ args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}"
+ else:
+ args = "empty=True"
+
+ return f"{self.__class__.__name__}({args})"
+
+ def __str__(self) -> str:
+ if not self._bounds:
+ return "empty"
+
+ items = [
+ self._bounds[0],
+ str(self._lower),
+ ", ",
+ str(self._upper),
+ self._bounds[1],
+ ]
+ return "".join(items)
+
+ @property
+ def lower(self) -> Optional[T]:
+ """The lower bound of the range. `!None` if empty or unbound."""
+ return self._lower
+
+ @property
+ def upper(self) -> Optional[T]:
+ """The upper bound of the range. `!None` if empty or unbound."""
+ return self._upper
+
+ @property
+ def bounds(self) -> str:
+ """The bounds string (two characters from '[', '(', ']', ')')."""
+ return self._bounds
+
+ @property
+ def isempty(self) -> bool:
+ """`!True` if the range is empty."""
+ return not self._bounds
+
+ @property
+ def lower_inf(self) -> bool:
+ """`!True` if the range doesn't have a lower bound."""
+ if not self._bounds:
+ return False
+ return self._lower is None
+
+ @property
+ def upper_inf(self) -> bool:
+ """`!True` if the range doesn't have an upper bound."""
+ if not self._bounds:
+ return False
+ return self._upper is None
+
+ @property
+ def lower_inc(self) -> bool:
+ """`!True` if the lower bound is included in the range."""
+ if not self._bounds or self._lower is None:
+ return False
+ return self._bounds[0] == "["
+
+ @property
+ def upper_inc(self) -> bool:
+ """`!True` if the upper bound is included in the range."""
+ if not self._bounds or self._upper is None:
+ return False
+ return self._bounds[1] == "]"
+
+ def __contains__(self, x: T) -> bool:
+ if not self._bounds:
+ return False
+
+ if self._lower is not None:
+ if self._bounds[0] == "[":
+ # It doesn't seem that Python has an ABC for ordered types.
+ if x < self._lower: # type: ignore[operator]
+ return False
+ else:
+ if x <= self._lower: # type: ignore[operator]
+ return False
+
+ if self._upper is not None:
+ if self._bounds[1] == "]":
+ if x > self._upper: # type: ignore[operator]
+ return False
+ else:
+ if x >= self._upper: # type: ignore[operator]
+ return False
+
+ return True
+
+ def __bool__(self) -> bool:
+ return bool(self._bounds)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return False
+ return (
+ self._lower == other._lower
+ and self._upper == other._upper
+ and self._bounds == other._bounds
+ )
+
+ def __hash__(self) -> int:
+ return hash((self._lower, self._upper, self._bounds))
+
+ # As the postgres docs describe for the server-side implementation,
+ # ordering is rather arbitrary, but will remain stable
+ # and consistent.
+
+ def __lt__(self, other: Any) -> bool:
+ if not isinstance(other, Range):
+ return NotImplemented
+ for attr in ("_lower", "_upper", "_bounds"):
+ self_value = getattr(self, attr)
+ other_value = getattr(other, attr)
+ if self_value == other_value:
+ pass
+ elif self_value is None:
+ return True
+ elif other_value is None:
+ return False
+ else:
+ return cast(bool, self_value < other_value)
+ return False
+
+ def __le__(self, other: Any) -> bool:
+ return self == other or self < other # type: ignore
+
+ def __gt__(self, other: Any) -> bool:
+ if isinstance(other, Range):
+ return other < self
+ else:
+ return NotImplemented
+
+ def __ge__(self, other: Any) -> bool:
+ return self == other or self > other # type: ignore
+
+ def __getstate__(self) -> Dict[str, Any]:
+ return {
+ slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
+ }
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ for slot, value in state.items():
+ setattr(self, slot, value)
+
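+# A quick illustration of the API above (not part of the module):
+#
+#     r = Range(10, 20, "[)")
+#     10 in r                      # True: the lower bound is inclusive
+#     20 in r                      # False: the upper bound is exclusive
+#     Range(empty=True).isempty    # True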
+
+# Subclasses to specify a specific subtype. Usually not needed: only needed
+# in binary copy, where switching to text is not an option.
+
+
+class Int4Range(Range[int]):
+ pass
+
+
+class Int8Range(Range[int]):
+ pass
+
+
+class NumericRange(Range[Decimal]):
+ pass
+
+
+class DateRange(Range[date]):
+ pass
+
+
+class TimestampRange(Range[datetime]):
+ pass
+
+
+class TimestamptzRange(Range[datetime]):
+ pass
+
+
+class BaseRangeDumper(RecursiveDumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self.sub_dumper: Optional[Dumper] = None
+ self._adapt_format = PyFormat.from_pq(self.format)
+
+ def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self.cls
+
+ item = self._get_item(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, self._adapt_format)
+ return (self.cls, sd.get_key(item, format))
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.cls is not Range:
+ return self
+
+ item = self._get_item(obj)
+ if item is None:
+ return RangeDumper(self.cls)
+
+ dumper: BaseRangeDumper
+ if type(item) is int:
+ # postgres won't cast int4range -> int8range so we must use
+ # text format and unknown oid here
+ sd = self._tx.get_dumper(item, PyFormat.TEXT)
+ dumper = RangeDumper(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ dumper.oid = INVALID_OID
+ return dumper
+
+ sd = self._tx.get_dumper(item, format)
+ dumper = type(self)(self.cls, self._tx)
+ dumper.sub_dumper = sd
+ if sd.oid == INVALID_OID and isinstance(item, str):
+ # Work around the normal mapping where text is dumped as unknown
+ dumper.oid = self._get_range_oid(TEXT_OID)
+ else:
+ dumper.oid = self._get_range_oid(sd.oid)
+
+ return dumper
+
+ def _get_item(self, obj: Range[Any]) -> Any:
+ """
+ Return a member representative of the range
+ """
+ rv = obj.lower
+ return rv if rv is not None else obj.upper
+
+ def _get_range_oid(self, sub_oid: int) -> int:
+ """
+ Return the oid of the range from the oid of its elements.
+ """
+ info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
+ return info.oid if info else INVALID_OID
+
+
+class RangeDumper(BaseRangeDumper):
+ """
+ Dumper for range types.
+
+ The dumper can upgrade to one specific for a different range type.
+ """
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_text(obj, dump)
+
+
+def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if obj.isempty:
+ return b"empty"
+
+ parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
+
+ def dump_item(item: Any) -> Buffer:
+ ad = dump(item)
+ if not ad:
+ return b'""'
+ elif _re_needs_quotes.search(ad):
+ return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"'
+ else:
+ return ad
+
+ if obj.lower is not None:
+ parts.append(dump_item(obj.lower))
+
+ parts.append(b",")
+
+ if obj.upper is not None:
+ parts.append(dump_item(obj.upper))
+
+ parts.append(b"]" if obj.upper_inc else b")")
+
+ return b"".join(parts)
+
+
+_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]')
+_re_esc = re.compile(rb"([\\\"])")
+
+
+class RangeBinaryDumper(BaseRangeDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: Range[Any]) -> Buffer:
+ item = self._get_item(obj)
+ if item is not None:
+ dump = self._tx.get_dumper(item, self._adapt_format).dump
+ else:
+ dump = fail_dump
+
+ return dump_range_binary(obj, dump)
+
+
+def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
+ if not obj:
+ return _EMPTY_HEAD
+
+ out = bytearray([0]) # will replace the head later
+
+ head = 0
+ if obj.lower_inc:
+ head |= RANGE_LB_INC
+ if obj.upper_inc:
+ head |= RANGE_UB_INC
+
+ if obj.lower is not None:
+ data = dump(obj.lower)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_LB_INF
+
+ if obj.upper is not None:
+ data = dump(obj.upper)
+ out += pack_len(len(data))
+ out += data
+ else:
+ head |= RANGE_UB_INF
+
+ out[0] = head
+ return out
+
+
+def fail_dump(obj: Any) -> Buffer:
+ raise e.InternalError("trying to dump a range element without information")
+
+
+class BaseRangeLoader(RecursiveLoader, Generic[T]):
+ """Generic loader for a range.
+
+ Subclasses must specify the oid of the subtype and the class to load.
+ """
+
+ subtype_oid: int
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
+
+
+class RangeLoader(BaseRangeLoader[T]):
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_text(data, self._load)[0]
+
+
+def load_range_text(
+ data: Buffer, load: Callable[[Buffer], Any]
+) -> Tuple[Range[Any], int]:
+ if data == b"empty":
+ return Range(empty=True), 5
+
+ m = _re_range.match(data)
+ if m is None:
+ raise e.DataError(
+ f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
+ )
+
+ lower = None
+ item = m.group(3)
+ if item is None:
+ item = m.group(2)
+ if item is not None:
+ lower = load(_re_undouble.sub(rb"\1", item))
+ else:
+ lower = load(item)
+
+ upper = None
+ item = m.group(5)
+ if item is None:
+ item = m.group(4)
+ if item is not None:
+ upper = load(_re_undouble.sub(rb"\1", item))
+ else:
+ upper = load(item)
+
+ bounds = (m.group(1) + m.group(6)).decode()
+
+ return Range(lower, upper, bounds), m.end()
+
+
+_re_range = re.compile(
+ rb"""
+ ( \(|\[ ) # lower bound flag
+ (?: # lower bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^",]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ,
+ (?: # upper bound:
+ " ( (?: [^"] | "")* ) " # - a quoted string
+ | ( [^"\)\]]+ ) # - or an unquoted string
+ )? # - or empty (not caught)
+ ( \)|\] ) # upper bound flag
+ """,
+ re.VERBOSE,
+)
+
+_re_undouble = re.compile(rb'(["\\])\1')
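+
+# For instance (illustrative), with a load function parsing integers,
+# load_range_text(b"[2,10)", load) returns (Range(2, 10, "[)"), 6).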
+
+
+class RangeBinaryLoader(BaseRangeLoader[T]):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Range[T]:
+ return load_range_binary(data, self._load)
+
+
+def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
+ head = data[0]
+ if head & RANGE_EMPTY:
+ return Range(empty=True)
+
+ lb = "[" if head & RANGE_LB_INC else "("
+ ub = "]" if head & RANGE_UB_INC else ")"
+
+ pos = 1 # after the head
+ if head & RANGE_LB_INF:
+ min = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ min = load(data[pos : pos + length])
+ pos += length
+
+ if head & RANGE_UB_INF:
+ max = None
+ else:
+ length = unpack_len(data, pos)[0]
+ pos += 4
+ max = load(data[pos : pos + length])
+ pos += length
+
+ return Range(min, max, lb + ub)
+
+
+def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register the adapters to load and dump a range type.
+
+ :param info: The object with the information about the range to register.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ Register loaders so that loading data of this type will result in a `Range`
+ with bounds parsed as the right subtype.
+
+ .. note::
+
+ Registering the adapters doesn't affect objects already created, even
+ if they are children of the registered context. For instance,
+ registering the adapter globally doesn't affect already existing
+ connections.
+ """
+ # A friendly error message instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the requested range available?")
+
+ # Register arrays and type info
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # generate and register a customized text loader
+ loader: Type[RangeLoader[Any]] = type(
+ f"{info.name.title()}Loader",
+ (RangeLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, loader)
+
+ # generate and register a customized binary loader
+ bloader: Type[RangeBinaryLoader[Any]] = type(
+ f"{info.name.title()}BinaryLoader",
+ (RangeBinaryLoader,),
+ {"subtype_oid": info.subtype_oid},
+ )
+ adapters.register_loader(info.oid, bloader)
+
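+# Minimal usage sketch (assumes an open connection 'conn' and a range type
+# named "floatrange" created in the database):
+#
+#     info = RangeInfo.fetch(conn, "floatrange")
+#     register_range(info, conn)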
+
+# Text dumpers for the builtin range type wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeDumper(RangeDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeDumper(RangeDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeDumper(RangeDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeDumper(RangeDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeDumper(RangeDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeDumper(RangeDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Binary dumpers for the builtin range type wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int4range"].oid
+
+
+class Int8RangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["int8range"].oid
+
+
+class NumericRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["numrange"].oid
+
+
+class DateRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["daterange"].oid
+
+
+class TimestampRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tsrange"].oid
+
+
+class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
+ oid = postgres.types["tstzrange"].oid
+
+
+# Text loaders for builtin range types
+
+
+class Int4RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeLoader(RangeLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeLoader(RangeLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeLoader(RangeLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeLoader(RangeLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+# Binary loaders for builtin range types
+
+
+class Int4RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int4"].oid
+
+
+class Int8RangeBinaryLoader(RangeBinaryLoader[int]):
+ subtype_oid = postgres.types["int8"].oid
+
+
+class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]):
+ subtype_oid = postgres.types["numeric"].oid
+
+
+class DateRangeBinaryLoader(RangeBinaryLoader[date]):
+ subtype_oid = postgres.types["date"].oid
+
+
+class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamp"].oid
+
+
+class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]):
+ subtype_oid = postgres.types["timestamptz"].oid
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper(Range, RangeBinaryDumper)
+ adapters.register_dumper(Range, RangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeDumper)
+ adapters.register_dumper(Int8Range, Int8RangeDumper)
+ adapters.register_dumper(NumericRange, NumericRangeDumper)
+ adapters.register_dumper(DateRange, DateRangeDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper)
+ adapters.register_dumper(Int4Range, Int4RangeBinaryDumper)
+ adapters.register_dumper(Int8Range, Int8RangeBinaryDumper)
+ adapters.register_dumper(NumericRange, NumericRangeBinaryDumper)
+ adapters.register_dumper(DateRange, DateRangeBinaryDumper)
+ adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper)
+ adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper)
+ adapters.register_loader("int4range", Int4RangeLoader)
+ adapters.register_loader("int8range", Int8RangeLoader)
+ adapters.register_loader("numrange", NumericRangeLoader)
+ adapters.register_loader("daterange", DateRangeLoader)
+ adapters.register_loader("tsrange", TimestampRangeLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeLoader)
+ adapters.register_loader("int4range", Int4RangeBinaryLoader)
+ adapters.register_loader("int8range", Int8RangeBinaryLoader)
+ adapters.register_loader("numrange", NumericRangeBinaryLoader)
+ adapters.register_loader("daterange", DateRangeBinaryLoader)
+ adapters.register_loader("tsrange", TimestampRangeBinaryLoader)
+ adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader)
diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py
new file mode 100644
index 0000000..e99f256
--- /dev/null
+++ b/psycopg/psycopg/types/shapely.py
@@ -0,0 +1,75 @@
+"""
+Adapters for PostGIS geometries
+"""
+
+from typing import Optional
+
+from .. import postgres
+from ..abc import AdaptContext, Buffer
+from ..adapt import Dumper, Loader
+from ..pq import Format
+from .._typeinfo import TypeInfo
+
+
+try:
+ from shapely.wkb import loads, dumps
+ from shapely.geometry.base import BaseGeometry
+
+except ImportError:
+ raise ImportError(
+ "The module psycopg.types.shapely requires the package 'Shapely'"
+ " to be installed"
+ )
+
+
+class GeometryBinaryLoader(Loader):
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "BaseGeometry":
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return loads(data)
+
+
+class GeometryLoader(Loader):
+ def load(self, data: Buffer) -> "BaseGeometry":
+ # in text format the data is a hex-encoded string of the binary representation
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return loads(data.decode(), hex=True)
+
+
+class BaseGeometryBinaryDumper(Dumper):
+ format = Format.BINARY
+
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj) # type: ignore
+
+
+class BaseGeometryDumper(Dumper):
+ def dump(self, obj: "BaseGeometry") -> bytes:
+ return dumps(obj, hex=True).encode() # type: ignore
+
+
+def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
+ """Register Shapely dumper and loaders."""
+
+ # A friendly error message instead of an AttributeError in case fetch()
+ # failed and it wasn't noticed.
+ if not info:
+ raise TypeError("no info passed. Is the 'postgis' extension loaded?")
+
+ info.register(context)
+ adapters = context.adapters if context else postgres.adapters
+
+ class GeometryDumper(BaseGeometryDumper):
+ oid = info.oid
+
+ class GeometryBinaryDumper(BaseGeometryBinaryDumper):
+ oid = info.oid
+
+ adapters.register_loader(info.oid, GeometryBinaryLoader)
+ adapters.register_loader(info.oid, GeometryLoader)
+ # Default binary dump
+ adapters.register_dumper(BaseGeometry, GeometryDumper)
+ adapters.register_dumper(BaseGeometry, GeometryBinaryDumper)
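+
+
+# Minimal usage sketch (assumes an open connection 'conn' and that the
+# 'postgis' extension is installed):
+#
+#     from psycopg.types import TypeInfo
+#     from psycopg.types.shapely import register_shapely
+#
+#     info = TypeInfo.fetch(conn, "geometry")
+#     register_shapely(info, conn)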
diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py
new file mode 100644
index 0000000..cd5360d
--- /dev/null
+++ b/psycopg/psycopg/types/string.py
@@ -0,0 +1,239 @@
+"""
+Adapters for textual types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Optional, Union, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format, Escaping
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+from ..errors import DataError
+from .._encodings import conn_encoding
+
+if TYPE_CHECKING:
+ from ..pq.abc import Escaping as EscapingProto
+
+
+class _BaseStrDumper(Dumper):
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else "utf-8"
+
+
+class _StrBinaryDumper(_BaseStrDumper):
+ """
+ Base class to dump Python strings to a Postgres text type, in binary format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ format = Format.BINARY
+
+ def dump(self, obj: str) -> bytes:
+ # the server will raise DataError subclass if the string contains 0x00
+ return obj.encode(self._encoding)
+
+
+class _StrDumper(_BaseStrDumper):
+ """
+ Base class to dump Python strings to a Postgres text type, in text format.
+
+ Subclasses shall specify the oids of real types (text, varchar, name...).
+ """
+
+ def dump(self, obj: str) -> bytes:
+ if "\x00" in obj:
+ raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
+ else:
+ return obj.encode(self._encoding)
+
+
+# The next are concrete dumpers, each one specifying the oid they dump to.
+
+
+class StrBinaryDumper(_StrBinaryDumper):
+
+ oid = postgres.types["text"].oid
+
+
+class StrBinaryDumperVarchar(_StrBinaryDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrBinaryDumperName(_StrBinaryDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumper(_StrDumper):
+ """
+ Dumper for strings in text format to the text oid.
+
+ Note that this dumper is not used by default because the type is too strict
+ and PostgreSQL would require explicit casts for everything that is not a
+ text field. However, it is useful where the unknown oid is ambiguous and the
+ text oid is required, for instance with variadic functions.
+ """
+
+ oid = postgres.types["text"].oid
+
+
+class StrDumperVarchar(_StrDumper):
+
+ oid = postgres.types["varchar"].oid
+
+
+class StrDumperName(_StrDumper):
+
+ oid = postgres.types["name"].oid
+
+
+class StrDumperUnknown(_StrDumper):
+ """
+ Dumper for strings in text format to the unknown oid.
+
+ This dumper is the default dumper for strings and makes it possible to use
+ Python strings to represent almost every data type. In a few places, however,
+ the unknown oid is not accepted (for instance in variadic functions such as
+ 'concat()'). In that case either a cast on the placeholder ('%s::text') or
+ the StrDumper should be used.
+ """
+
+ pass
+
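+# For instance (illustrative), with this default dumper an unknown-typed
+# placeholder is rejected by variadic functions, while an explicit cast works:
+#
+#     cur.execute("SELECT concat(%s, %s)", ["foo", "bar"])              # error
+#     cur.execute("SELECT concat(%s::text, %s::text)", ["foo", "bar"])  # ok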
+
+class TextLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ enc = conn_encoding(self.connection)
+ self._encoding = enc if enc != "ascii" else ""
+
+ def load(self, data: Buffer) -> Union[bytes, str]:
+ if self._encoding:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return data.decode(self._encoding)
+ else:
+ # return bytes for SQL_ASCII db
+ if not isinstance(data, bytes):
+ data = bytes(data)
+ return data
+
+
+class TextBinaryLoader(TextLoader):
+
+ format = Format.BINARY
+
+
+class BytesDumper(Dumper):
+
+ oid = postgres.types["bytea"].oid
+ _qprefix = b""
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ super().__init__(cls, context)
+ self._esc = Escaping(self.connection.pgconn if self.connection else None)
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return self._esc.escape_bytea(obj)
+
+ def quote(self, obj: Buffer) -> bytes:
+ escaped = self.dump(obj)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the content to put between quotes. If scs (standard_conforming_strings)
+ # is off it will escape the backslashes in the format, otherwise it
+ # won't, but it doesn't tell us what quotes to use.
+ if self.connection:
+ if not self._qprefix:
+ scs = self.connection.pgconn.parameter_status(
+ b"standard_conforming_strings"
+ )
+ self._qprefix = b"'" if scs == b"on" else b" E'"
+
+ return self._qprefix + escaped + b"'"
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. It is not predictable whether
+ # PQescapeBytea, like its string counterpart, will escape the
+ # backslashes.
+ rv: bytes = b" E'" + escaped + b"'"
+ if self._esc.escape_bytea(b"\x00") == b"\\000":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
+
+
+class BytesBinaryDumper(Dumper):
+
+ format = Format.BINARY
+ oid = postgres.types["bytea"].oid
+
+ def dump(self, obj: Buffer) -> Buffer:
+ return obj
+
+
+class ByteaLoader(Loader):
+
+ _escaping: "EscapingProto"
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ if not hasattr(self.__class__, "_escaping"):
+ self.__class__._escaping = Escaping()
+
+ def load(self, data: Buffer) -> bytes:
+ return self._escaping.unescape_bytea(data)
+
+
+class ByteaBinaryLoader(Loader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> Buffer:
+ return data
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+
+ # NOTE: the order in which the dumpers are registered is relevant. The last
+ # one registered becomes the default for each type. Usually, binary is the
+ # default dumper. For text we use the text dumper as default because it
+ # plays the role of unknown, and it can be cast automatically to other
+ # types. However, before that, we register dumpers with the 'text',
+ # 'varchar' and 'name' oids, which will be used when a text dumper is
+ # looked up by oid.
+ adapters.register_dumper(str, StrBinaryDumperName)
+ adapters.register_dumper(str, StrBinaryDumperVarchar)
+ adapters.register_dumper(str, StrBinaryDumper)
+ adapters.register_dumper(str, StrDumperName)
+ adapters.register_dumper(str, StrDumperVarchar)
+ adapters.register_dumper(str, StrDumper)
+ adapters.register_dumper(str, StrDumperUnknown)
+
+ adapters.register_loader(postgres.INVALID_OID, TextLoader)
+ adapters.register_loader("bpchar", TextLoader)
+ adapters.register_loader("name", TextLoader)
+ adapters.register_loader("text", TextLoader)
+ adapters.register_loader("varchar", TextLoader)
+ adapters.register_loader('"char"', TextLoader)
+ adapters.register_loader("bpchar", TextBinaryLoader)
+ adapters.register_loader("name", TextBinaryLoader)
+ adapters.register_loader("text", TextBinaryLoader)
+ adapters.register_loader("varchar", TextBinaryLoader)
+ adapters.register_loader('"char"', TextBinaryLoader)
+
+ adapters.register_dumper(bytes, BytesDumper)
+ adapters.register_dumper(bytearray, BytesDumper)
+ adapters.register_dumper(memoryview, BytesDumper)
+ adapters.register_dumper(bytes, BytesBinaryDumper)
+ adapters.register_dumper(bytearray, BytesBinaryDumper)
+ adapters.register_dumper(memoryview, BytesBinaryDumper)
+
+ adapters.register_loader("bytea", ByteaLoader)
+ adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader)
+ adapters.register_loader("bytea", ByteaBinaryLoader)
diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py
new file mode 100644
index 0000000..f92354c
--- /dev/null
+++ b/psycopg/psycopg/types/uuid.py
@@ -0,0 +1,65 @@
+"""
+Adapters for the UUID type.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Callable, Optional, TYPE_CHECKING
+
+from .. import postgres
+from ..pq import Format
+from ..abc import AdaptContext
+from ..adapt import Buffer, Dumper, Loader
+
+if TYPE_CHECKING:
+ import uuid
+
+# Importing the uuid module is slow, so import it only on request.
+UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment]
+
+
+class UUIDDumper(Dumper):
+
+ oid = postgres.types["uuid"].oid
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.hex.encode()
+
+
+class UUIDBinaryDumper(UUIDDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: "uuid.UUID") -> bytes:
+ return obj.bytes
+
+
+class UUIDLoader(Loader):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+ global UUID
+ if UUID is None:
+ from uuid import UUID
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(data.decode())
+
+
+class UUIDBinaryLoader(UUIDLoader):
+
+ format = Format.BINARY
+
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return UUID(bytes=data)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ adapters = context.adapters
+ adapters.register_dumper("uuid.UUID", UUIDDumper)
+ adapters.register_dumper("uuid.UUID", UUIDBinaryDumper)
+ adapters.register_loader("uuid", UUIDLoader)
+ adapters.register_loader("uuid", UUIDBinaryLoader)
diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py
new file mode 100644
index 0000000..a98bc35
--- /dev/null
+++ b/psycopg/psycopg/version.py
@@ -0,0 +1,14 @@
+"""
+psycopg distribution version file.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+
+# STOP AND READ! if you change:
+__version__ = "3.1.7"
+# also change:
+# - `docs/news.rst` to declare this as the current version or an unreleased one
+# - `psycopg_c/psycopg_c/version.py` to the same version.
diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py
new file mode 100644
index 0000000..7abfc58
--- /dev/null
+++ b/psycopg/psycopg/waiting.py
@@ -0,0 +1,331 @@
+"""
+Code concerned with waiting in different contexts (blocking, async, etc).
+
+These functions are designed to consume the generators returned by the
+`generators` module functions and to return their final value.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+import os
+import select
+import selectors
+from typing import Dict, Optional
+from asyncio import get_event_loop, wait_for, Event, TimeoutError
+from selectors import DefaultSelector
+
+from . import errors as e
+from .abc import RV, PQGen, PQGenConn, WaitFunc
+from ._enums import Wait as Wait, Ready as Ready # re-exported
+from ._cmodule import _psycopg
+
+WAIT_R = Wait.R
+WAIT_W = Wait.W
+WAIT_RW = Wait.RW
+READY_R = Ready.R
+READY_W = Ready.W
+READY_RW = Ready.RW
+
+
+def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :param timeout: timeout (in seconds) to check for other interrupts, e.g.
+ to allow Ctrl-C.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Consume `!gen`, scheduling `fileno` for completion when it is reported to
+ block. Once ready again send the ready state back to `!gen`.
+ """
+ try:
+ s = next(gen)
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = None
+ while not rlist:
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ # note: this line should require a cast, but mypy doesn't complain
+ ready: Ready = rlist[0][1]
+ assert s & ready
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a connection generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupts, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :type timeout: float
+ :return: whatever `!gen` returns on completion.
+
+ Behave like `wait()`, but take the fileno to wait on from the generator
+ itself, which might change during processing.
+ """
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ with DefaultSelector() as sel:
+ while True:
+ sel.register(fileno, s)
+ rlist = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ if not rlist:
+ raise e.ConnectionTimeout("connection timeout expired")
+ ready: Ready = rlist[0][1] # type: ignore[assignment]
+ fileno, s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
+ """
+ Coroutine waiting for a generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like `wait()`, but expose an `asyncio` interface.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready |= state # type: ignore[assignment]
+ ev.set()
+
+ try:
+ s = next(gen)
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await ev.wait()
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Coroutine waiting for a connection generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupts, e.g.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
+ :return: whatever `!gen` returns on completion.
+
+ Behave like `wait()`, but take the fileno to wait on from the generator
+ itself, which might change during processing.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready = state
+ ev.set()
+
+ try:
+ fileno, s = next(gen)
+ if not timeout:
+ timeout = None
+ while True:
+ reader = s & WAIT_R
+ writer = s & WAIT_W
+ if not reader and not writer:
+ raise e.InternalError(f"bad poll status: {s}")
+ ev.clear()
+ ready = 0 # type: ignore[assignment]
+ if reader:
+ loop.add_reader(fileno, wakeup, READY_R)
+ if writer:
+ loop.add_writer(fileno, wakeup, READY_W)
+ try:
+ await wait_for(ev.wait(), timeout)
+ finally:
+ if reader:
+ loop.remove_reader(fileno)
+ if writer:
+ loop.remove_writer(fileno)
+ fileno, s = gen.send(ready)
+
+ except TimeoutError:
+ raise e.ConnectionTimeout("connection timeout expired")
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+# Specialised implementations of the wait functions.
+
+
+def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using select where supported.
+ """
+ try:
+ s = next(gen)
+
+ empty = ()
+ fnlist = (fileno,)
+ while True:
+ rl, wl, xl = select.select(
+ fnlist if s & WAIT_R else empty,
+ fnlist if s & WAIT_W else empty,
+ fnlist,
+ timeout,
+ )
+ ready = 0
+ if rl:
+ ready = READY_R
+ if wl:
+ ready |= READY_W
+ if not ready:
+ continue
+ # assert s & ready
+ s = gen.send(ready) # type: ignore
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+poll_evmasks: Dict[Wait, int]
+
+if hasattr(selectors, "EpollSelector"):
+ poll_evmasks = {
+ WAIT_R: select.EPOLLONESHOT | select.EPOLLIN,
+ WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT,
+ WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT,
+ }
+else:
+ poll_evmasks = {}
+
+
+def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a generator using epoll where supported.
+
+ Parameters are like for `wait()`. If it is detected that the best selector
+ strategy is `epoll` then this function will be used instead of `wait`.
+
+ See also: https://linux.die.net/man/2/epoll_ctl
+ """
+ try:
+ s = next(gen)
+
+ if timeout is None or timeout < 0:
+ timeout = 0
+ else:
+ timeout = int(timeout * 1000.0)
+
+ with select.epoll() as epoll:
+ evmask = poll_evmasks[s]
+ epoll.register(fileno, evmask)
+ while True:
+ fileevs = None
+ while not fileevs:
+ fileevs = epoll.poll(timeout)
+ ev = fileevs[0][1]
+ ready = 0
+ if ev & ~select.EPOLLOUT:
+ ready = READY_R
+ if ev & ~select.EPOLLIN:
+ ready |= READY_W
+ # assert s & ready
+ s = gen.send(ready)
+ evmask = poll_evmasks[s]
+ epoll.modify(fileno, evmask)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+if _psycopg:
+ wait_c = _psycopg.wait_c
+
+
+# Choose the best wait strategy for the platform.
+#
+# The selectors objects have a generic interface but come with some overhead,
+# so we also offer more finely tuned implementations.
+
+wait: WaitFunc
+
+# Allow the user to choose a specific function for testing
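+# (e.g. PSYCOPG_WAIT_FUNC=wait_selector forces the generic selectors-based
+# implementation)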
+if "PSYCOPG_WAIT_FUNC" in os.environ:
+ fname = os.environ["PSYCOPG_WAIT_FUNC"]
+ if not fname.startswith("wait_") or fname not in globals():
+ raise ImportError(
+ "PSYCOPG_WAIT_FUNC should be the name of an available wait function;"
+ f" got {fname!r}"
+ )
+ wait = globals()[fname]
+
+elif _psycopg:
+ wait = wait_c
+
+elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
+ # On Windows, SelectSelector should be the default.
+ wait = wait_select
+
+elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None):
+ # NOTE: select seems to perform better than epoll. It is admittedly unlikely
+ # that a platform has epoll but not select, so maybe we could kill
+ # wait_epoll() altogether. More testing to do.
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll
+
+elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None):
+ # wait_select is faster than wait_selector, probably because of less overhead
+ wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector
+
+else:
+ wait = wait_selector
diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml
new file mode 100644
index 0000000..21e410c
--- /dev/null
+++ b/psycopg/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37"]
+build-backend = "setuptools.build_meta"
diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg
new file mode 100644
index 0000000..fdcb612
--- /dev/null
+++ b/psycopg/setup.cfg
@@ -0,0 +1,47 @@
+[metadata]
+name = psycopg
+description = PostgreSQL database adapter for Python
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+packages = find:
+zip_safe = False
+install_requires =
+ backports.zoneinfo >= 0.2.0; python_version < "3.9"
+ typing-extensions >= 4.1
+ tzdata; sys_platform == "win32"
+
+[options.package_data]
+psycopg = py.typed
diff --git a/psycopg/setup.py b/psycopg/setup.py
new file mode 100644
index 0000000..90d4380
--- /dev/null
+++ b/psycopg/setup.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - pure Python package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+from setuptools import setup
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+# Only for release 3.1.7. Not building binary packages because Scaleway
+# has no runner available, but psycopg-binary 3.1.6 should work as well,
+# since the only change is in rows.py.
+version = "3.1.7"
+ext_versions = ">= 3.1.6, <= 3.1.7"
+
+extras_require = {
+ # Install the C extension module (requires dev tools)
+ "c": [
+ f"psycopg-c {ext_versions}",
+ ],
+ # Install the stand-alone C extension module
+ "binary": [
+ f"psycopg-binary {ext_versions}",
+ ],
+ # Install the connection pool
+ "pool": [
+ "psycopg-pool",
+ ],
+ # Requirements to run the test suite
+ "test": [
+ "mypy >= 0.990",
+ "pproxy >= 2.7",
+ "pytest >= 6.2.5",
+ "pytest-asyncio >= 0.17",
+ "pytest-cov >= 3.0",
+ "pytest-randomly >= 3.10",
+ ],
+ # Requirements needed for development
+ "dev": [
+ "black >= 22.3.0",
+ "dnspython >= 2.1",
+ "flake8 >= 4.0",
+ "mypy >= 0.990",
+ "types-setuptools >= 57.4",
+ "wheel >= 0.37",
+ ],
+ # Requirements needed to build the documentation
+ "docs": [
+ "Sphinx >= 5.0",
+ "furo == 2022.6.21",
+ "sphinx-autobuild >= 2021.3.14",
+ "sphinx-autodoc-typehints >= 1.12",
+ ],
+}
+
+setup(
+ version=version,
+ extras_require=extras_require,
+)
diff --git a/psycopg_c/.flake8 b/psycopg_c/.flake8
new file mode 100644
index 0000000..2ae629c
--- /dev/null
+++ b/psycopg_c/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
diff --git a/psycopg_c/LICENSE.txt b/psycopg_c/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg_c/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg_c/README-binary.rst b/psycopg_c/README-binary.rst
new file mode 100644
index 0000000..9318d57
--- /dev/null
+++ b/psycopg_c/README-binary.rst
@@ -0,0 +1,29 @@
+Psycopg 3: PostgreSQL database adapter for Python - binary package
+==================================================================
+
+This distribution contains the precompiled optimization package
+``psycopg_binary``.
+
+You shouldn't install this package directly: use instead ::
+
+ pip install "psycopg[binary]"
+
+to install a version of the optimization package matching the ``psycopg``
+version installed.
+
+Installing this package requires pip 20.3 or newer.
+
+This package is not available for every platform: check out `Binary
+installation`__ in the documentation.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #binary-installation
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_c/README.rst b/psycopg_c/README.rst
new file mode 100644
index 0000000..de9ba93
--- /dev/null
+++ b/psycopg_c/README.rst
@@ -0,0 +1,33 @@
+Psycopg 3: PostgreSQL database adapter for Python - optimisation package
+========================================================================
+
+This distribution contains the optional optimization package ``psycopg_c``.
+
+You shouldn't install this package directly: use instead ::
+
+ pip install "psycopg[c]"
+
+to install a version of the optimization package matching the ``psycopg``
+version installed.
+
+Installing this package requires some prerequisites: check `Local
+installation`__ in the documentation. Without a C compiler and some library
+headers the install *will fail*: this is not a bug.
+
+If you are unable to meet the prerequisites you might want to install
+``psycopg[binary]`` instead: look for `Binary installation`__ in the
+documentation.
+
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #local-installation
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #binary-installation
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_c/psycopg_c/.gitignore b/psycopg_c/psycopg_c/.gitignore
new file mode 100644
index 0000000..36edb64
--- /dev/null
+++ b/psycopg_c/psycopg_c/.gitignore
@@ -0,0 +1,4 @@
+/*.so
+_psycopg.c
+pq.c
+*.html
diff --git a/psycopg_c/psycopg_c/__init__.py b/psycopg_c/psycopg_c/__init__.py
new file mode 100644
index 0000000..14db92b
--- /dev/null
+++ b/psycopg_c/psycopg_c/__init__.py
@@ -0,0 +1,14 @@
+"""
+psycopg -- PostgreSQL database adapter for Python -- C optimization package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import sys
+
+# This package shouldn't be imported before psycopg itself, or weird things
+# will happen
+if "psycopg" not in sys.modules:
+ raise ImportError("the psycopg package should be imported before psycopg_c")
+
+from .version import __version__ as __version__ # noqa
diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi
new file mode 100644
index 0000000..bd7c63d
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg.pyi
@@ -0,0 +1,84 @@
+"""
+Stub representation of the public objects exposed by the _psycopg module.
+
+TODO: this should be generated by mypy's stubgen but it crashes with no
+information. Will submit a bug.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Iterable, List, Optional, Sequence, Tuple
+
+from psycopg import pq
+from psycopg import abc
+from psycopg.rows import Row, RowMaker
+from psycopg.adapt import AdaptersMap, PyFormat
+from psycopg.pq.abc import PGconn, PGresult
+from psycopg.connection import BaseConnection
+from psycopg._compat import Deque
+
+class Transformer(abc.AdaptContext):
+ types: Optional[Tuple[int, ...]]
+ formats: Optional[List[pq.Format]]
+ def __init__(self, context: Optional[abc.AdaptContext] = None): ...
+ @classmethod
+ def from_context(cls, context: Optional[abc.AdaptContext]) -> "Transformer": ...
+ @property
+ def connection(self) -> Optional[BaseConnection[Any]]: ...
+ @property
+ def encoding(self) -> str: ...
+ @property
+ def adapters(self) -> AdaptersMap: ...
+ @property
+ def pgresult(self) -> Optional[PGresult]: ...
+ def set_pgresult(
+ self,
+ result: Optional["PGresult"],
+ *,
+ set_loaders: bool = True,
+ format: Optional[pq.Format] = None,
+ ) -> None: ...
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+ def dump_sequence(
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
+ ) -> Sequence[Optional[abc.Buffer]]: ...
+ def as_literal(self, obj: Any) -> bytes: ...
+ def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
+ def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
+ def load_sequence(
+ self, record: Sequence[Optional[abc.Buffer]]
+ ) -> Tuple[Any, ...]: ...
+ def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
+
+# Generators
+def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ...
+def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
+def send(pgconn: PGconn) -> abc.PQGen[None]: ...
+def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
+def fetch(pgconn: PGconn) -> abc.PQGen[Optional[PGresult]]: ...
+def pipeline_communicate(
+ pgconn: PGconn, commands: Deque[abc.PipelineCommand]
+) -> abc.PQGen[List[List[PGresult]]]: ...
+def wait_c(
+ gen: abc.PQGen[abc.RV], fileno: int, timeout: Optional[float] = None
+) -> abc.RV: ...
+
+# Copy support
+def format_row_text(
+ row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None
+) -> bytearray: ...
+def format_row_binary(
+ row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None
+) -> bytearray: ...
+def parse_row_text(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_binary(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+
+# Arrays optimization
+def array_load_text(
+ data: abc.Buffer, loader: abc.Loader, delimiter: bytes = b","
+) -> List[Any]: ...
+def array_load_binary(data: abc.Buffer, tx: abc.Transformer) -> List[Any]: ...
+
+# vim: set syntax=python:
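The copy and array helpers declared above are internal, but their effect is visible
through the public API; for instance (placeholder connection string), a PostgreSQL
array column is loaded as a Python list::

    import psycopg

    with psycopg.connect("dbname=test") as conn:   # placeholder DSN
        cur = conn.execute("select '{1,2,3}'::int4[], '{foo,bar}'::text[]")
        ints, texts = cur.fetchone()
        print(ints)    # [1, 2, 3]
        print(texts)   # ['foo', 'bar']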
diff --git a/psycopg_c/psycopg_c/_psycopg.pyx b/psycopg_c/psycopg_c/_psycopg.pyx
new file mode 100644
index 0000000..9d2b8ba
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg.pyx
@@ -0,0 +1,48 @@
+"""
+psycopg_c._psycopg optimization module.
+
+The module contains optimized C code used in preference to Python code
+if a compiler is available.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c cimport pq
+from psycopg_c.pq cimport libpq
+from psycopg_c._psycopg cimport oids
+
+import logging
+
+from psycopg.pq import Format as _pq_Format
+from psycopg._enums import PyFormat as _py_Format
+
+logger = logging.getLogger("psycopg")
+
+PQ_TEXT = _pq_Format.TEXT
+PQ_BINARY = _pq_Format.BINARY
+
+PG_AUTO = _py_Format.AUTO
+PG_TEXT = _py_Format.TEXT
+PG_BINARY = _py_Format.BINARY
+
+
+cdef extern from *:
+ """
+#ifndef ARRAYSIZE
+#define ARRAYSIZE(a) ((sizeof(a) / sizeof(*(a))))
+#endif
+ """
+ int ARRAYSIZE(void *array)
+
+
+include "_psycopg/adapt.pyx"
+include "_psycopg/copy.pyx"
+include "_psycopg/generators.pyx"
+include "_psycopg/transform.pyx"
+include "_psycopg/waiting.pyx"
+
+include "types/array.pyx"
+include "types/datetime.pyx"
+include "types/numeric.pyx"
+include "types/bool.pyx"
+include "types/string.pyx"
diff --git a/psycopg_c/psycopg_c/_psycopg/__init__.pxd b/psycopg_c/psycopg_c/_psycopg/__init__.pxd
new file mode 100644
index 0000000..db22deb
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/__init__.pxd
@@ -0,0 +1,9 @@
+"""
+psycopg_c._psycopg cython module.
+
+This file is necessary to allow c-importing pxd files from this directory.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c._psycopg cimport oids
diff --git a/psycopg_c/psycopg_c/_psycopg/adapt.pyx b/psycopg_c/psycopg_c/_psycopg/adapt.pyx
new file mode 100644
index 0000000..a6d8e6a
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/adapt.pyx
@@ -0,0 +1,171 @@
+"""
+C implementation of the adaptation system.
+
+This module maps each Python adaptation function to a C adaptation function.
+Notice that C adaptation functions have a different signature because they can
+avoid making a memory copy; however, this makes it impossible to expose them to
+Python.
+
+This module exposes facilities to map the builtin adapters in Python to
+equivalent C implementations.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any
+
+cimport cython
+
+from libc.string cimport memcpy, memchr
+from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
+from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING
+
+from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping
+
+from psycopg import errors as e
+from psycopg.pq.misc import error_message
+
+
+@cython.freelist(8)
+cdef class CDumper:
+
+ cdef readonly object cls
+ cdef pq.PGconn _pgconn
+
+ oid = oids.INVALID_OID
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+ self.cls = cls
+ conn = context.connection if context is not None else None
+ self._pgconn = conn.pgconn if conn is not None else None
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ """Store the Postgres representation *obj* into *rv* at *offset*
+
+ Return the number of bytes written to rv or -1 on Python exception.
+
+ Subclasses must implement this method. The `dump()` implementation
+ transforms the result of this method to a bytearray so that it can be
+ returned to Python.
+
+ The function interface allows C code to use this method automatically
+ to create larger buffers, e.g. for copy, composite objects, etc.
+
+ Implementation note: as you will always need to make sure that rv
+ has enough space to include what you want to dump, `ensure_size()`
+ will probably come in handy.
+ """
+ raise NotImplementedError()
+
+ def dump(self, obj):
+ """Return the Postgres representation of *obj* as Python array of bytes"""
+ cdef rv = PyByteArray_FromStringAndSize("", 0)
+ cdef Py_ssize_t length = self.cdump(obj, rv, 0)
+ PyByteArray_Resize(rv, length)
+ return rv
+
+ def quote(self, obj):
+ cdef char *ptr
+ cdef char *ptr_out
+ cdef Py_ssize_t length
+
+ value = self.dump(obj)
+
+ if self._pgconn is not None:
+ esc = Escaping(self._pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+
+ # This path is taken when quote is called without a connection,
+ # usually by psycopg.sql.quote() or by
+ # 'Composable.as_string(None)'. More often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+
+ # No quoting, only quote escaping and random backslash escaping. See below.
+ esc = Escaping()
+ out = esc.escape_string(value)
+
+ _buffer_as_string_and_size(out, &ptr, &length)
+
+ if not memchr(ptr, b'\\', length):
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ PyByteArray_Resize(rv, length + 2) # Must include the quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b"'"
+ memcpy(ptr_out + 1, ptr, length)
+ ptr_out[length + 1] = b"'"
+ return rv
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ if esc.escape_string(b"\\") == b"\\":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
+ return rv
+
+ cpdef object get_key(self, object obj, object format):
+ return self.cls
+
+ cpdef object upgrade(self, object obj, object format):
+ return self
+
+ @staticmethod
+ cdef char *ensure_size(bytearray ba, Py_ssize_t offset, Py_ssize_t size) except NULL:
+ """
+ Grow *ba*, if necessary, to contain at least *size* bytes after *offset*
+
+ Return the pointer in the bytearray at *offset*, i.e. the place where
+ you want to write *size* bytes.
+ """
+ cdef Py_ssize_t curr_size = PyByteArray_GET_SIZE(ba)
+ cdef Py_ssize_t new_size = offset + size
+ if curr_size < new_size:
+ PyByteArray_Resize(ba, new_size)
+
+ return PyByteArray_AS_STRING(ba) + offset
+
+
+@cython.freelist(8)
+cdef class CLoader:
+ cdef public libpq.Oid oid
+ cdef pq.PGconn _pgconn
+
+ def __cinit__(self, int oid, context: Optional[AdaptContext] = None):
+ self.oid = oid
+ conn = context.connection if context is not None else None
+ self._pgconn = conn.pgconn if conn is not None else None
+
+ cdef object cload(self, const char *data, size_t length):
+ raise NotImplementedError()
+
+ def load(self, object data) -> Any:
+ cdef char *ptr
+ cdef Py_ssize_t length
+ _buffer_as_string_and_size(data, &ptr, &length)
+ return self.cload(ptr, length)
+
+
+cdef class _CRecursiveLoader(CLoader):
+
+ cdef Transformer _tx
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+ self._tx = Transformer.from_context(context)
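As a plain-Python sketch of the no-connection branch of ``CDumper.quote()`` above
(illustrative only, not the code that actually runs): quotes are doubled, and a value
containing backslashes is emitted in the backslash-safe `` E'...'`` form with the
backslashes doubled as well::

    def quote_no_connection(value: bytes) -> bytes:
        # double the quotes, as PQescapeString would do
        escaped = value.replace(b"'", b"''")
        if b"\\" not in escaped:
            return b"'" + escaped + b"'"
        # the E'...' syntax is valid regardless of standard_conforming_strings
        return b" E'" + escaped.replace(b"\\", b"\\\\") + b"'"

    print(quote_no_connection(b"it's"))   # b"'it''s'"
    print(quote_no_connection(b"a\\b"))   # b" E'a\\\\b'"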
diff --git a/psycopg_c/psycopg_c/_psycopg/copy.pyx b/psycopg_c/psycopg_c/_psycopg/copy.pyx
new file mode 100644
index 0000000..b943095
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/copy.pyx
@@ -0,0 +1,340 @@
+"""
+C optimised functions for the copy system.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.string cimport memcpy
+from libc.stdint cimport uint16_t, uint32_t, int32_t
+from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
+from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE
+from cpython.memoryview cimport PyMemoryView_FromObject
+
+from psycopg_c._psycopg cimport endian
+from psycopg_c.pq cimport ViewBuffer
+
+from psycopg import errors as e
+
+cdef int32_t _binary_null = -1
+
+
+def format_row_binary(
+ row: Sequence[Any], tx: Transformer, out: bytearray = None
+) -> bytearray:
+ """Convert a row of adapted data to the data to send for binary copy"""
+ cdef Py_ssize_t rowlen = len(row)
+ cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen)
+
+ cdef Py_ssize_t pos # offset in 'out' where to write
+ if out is None:
+ out = PyByteArray_FromStringAndSize("", 0)
+ pos = 0
+ else:
+ pos = PyByteArray_GET_SIZE(out)
+
+ # let's start from a nice chunk
+ # (larger than most fixed-size values; for variable ones, oh well, we'll resize it)
+ cdef char *target = CDumper.ensure_size(
+ out, pos, sizeof(berowlen) + 20 * rowlen)
+
+ # Write the number of fields as network-order 16 bits
+ memcpy(target, <void *>&berowlen, sizeof(berowlen))
+ pos += sizeof(berowlen)
+
+ cdef Py_ssize_t size
+ cdef uint32_t besize
+ cdef char *buf
+ cdef int i
+ cdef PyObject *fmt = <PyObject *>PG_BINARY
+ cdef PyObject *row_dumper
+
+ if not tx._row_dumpers:
+ tx._row_dumpers = PyList_New(rowlen)
+
+ dumpers = tx._row_dumpers
+
+ for i in range(rowlen):
+ item = row[i]
+ if item is None:
+ target = CDumper.ensure_size(out, pos, sizeof(_binary_null))
+ memcpy(target, <void *>&_binary_null, sizeof(_binary_null))
+ pos += sizeof(_binary_null)
+ continue
+
+ row_dumper = PyList_GET_ITEM(dumpers, i)
+ if not row_dumper:
+ row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ Py_INCREF(<object>row_dumper)
+ PyList_SET_ITEM(dumpers, i, <object>row_dumper)
+
+ if (<RowDumper>row_dumper).cdumper is not None:
+ # A cdumper can resize if necessary and copy in place
+ size = (<RowDumper>row_dumper).cdumper.cdump(
+ item, out, pos + sizeof(besize))
+ # Also add the size of the item, before the item
+ besize = endian.htobe32(<int32_t>size)
+ target = PyByteArray_AS_STRING(out) # might have been moved by cdump
+ memcpy(target + pos, <void *>&besize, sizeof(besize))
+ else:
+ # A Python dumper, gotta call it and extract its juices
+ b = PyObject_CallFunctionObjArgs(
+ (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
+ _buffer_as_string_and_size(b, &buf, &size)
+ target = CDumper.ensure_size(out, pos, size + sizeof(besize))
+ besize = endian.htobe32(<int32_t>size)
+ memcpy(target, <void *>&besize, sizeof(besize))
+ memcpy(target + sizeof(besize), buf, size)
+
+ pos += size + sizeof(besize)
+
+ # Resize to the final size
+ PyByteArray_Resize(out, pos)
+ return out
+
+
+def format_row_text(
+ row: Sequence[Any], tx: Transformer, out: bytearray = None
+) -> bytearray:
+ cdef Py_ssize_t pos # offset in 'out' where to write
+ if out is None:
+ out = PyByteArray_FromStringAndSize("", 0)
+ pos = 0
+ else:
+ pos = PyByteArray_GET_SIZE(out)
+
+ cdef Py_ssize_t rowlen = len(row)
+
+ if rowlen == 0:
+ PyByteArray_Resize(out, pos + 1)
+ out[pos] = b"\n"
+ return out
+
+ cdef Py_ssize_t size, tmpsize
+ cdef char *buf
+ cdef int i, j
+ cdef unsigned char *target
+ cdef int nesc = 0
+ cdef int with_tab
+ cdef PyObject *fmt = <PyObject *>PG_TEXT
+ cdef PyObject *row_dumper
+
+ for i in range(rowlen):
+ # Include the tab before the data, so it gets included in the resizes
+ with_tab = i > 0
+
+ item = row[i]
+ if item is None:
+ if with_tab:
+ target = <unsigned char *>CDumper.ensure_size(out, pos, 3)
+ memcpy(target, b"\t\\N", 3)
+ pos += 3
+ else:
+ target = <unsigned char *>CDumper.ensure_size(out, pos, 2)
+ memcpy(target, b"\\N", 2)
+ pos += 2
+ continue
+
+ row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ if (<RowDumper>row_dumper).cdumper is not None:
+ # A cdumper can resize if necessary and copy in place
+ size = (<RowDumper>row_dumper).cdumper.cdump(
+ item, out, pos + with_tab)
+ target = <unsigned char *>PyByteArray_AS_STRING(out) + pos
+ else:
+ # A Python dumper, gotta call it and extract its juices
+ b = PyObject_CallFunctionObjArgs(
+ (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
+ _buffer_as_string_and_size(b, &buf, &size)
+ target = <unsigned char *>CDumper.ensure_size(out, pos, size + with_tab)
+ memcpy(target + with_tab, buf, size)
+
+ # Prepend a tab to the data just written
+ if with_tab:
+ target[0] = b"\t"
+ target += 1
+ pos += 1
+
+ # Now from pos to pos + size there is a textual representation: it may
+ # contain chars to escape. Scan to find how many such chars there are.
+ for j in range(size):
+ if copy_escape_lut[target[j]]:
+ nesc += 1
+
+ # If there is any char to escape, walk backwards pushing the chars
+ # forward and interspersing backslashes.
+ if nesc > 0:
+ tmpsize = size + nesc
+ target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize)
+ for j in range(<int>size - 1, -1, -1):
+ if copy_escape_lut[target[j]]:
+ target[j + nesc] = copy_escape_lut[target[j]]
+ nesc -= 1
+ target[j + nesc] = b"\\"
+ if nesc <= 0:
+ break
+ else:
+ target[j + nesc] = target[j]
+ pos += tmpsize
+ else:
+ pos += size
+
+ # Resize to the final size, add the newline
+ PyByteArray_Resize(out, pos + 1)
+ out[pos] = b"\n"
+ return out
+
+
+def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]:
+ cdef unsigned char *ptr
+ cdef Py_ssize_t bufsize
+ _buffer_as_string_and_size(data, <char **>&ptr, &bufsize)
+ cdef unsigned char *bufend = ptr + bufsize
+
+ cdef uint16_t benfields = (<uint16_t *>ptr)[0]
+ cdef int nfields = endian.be16toh(benfields)
+ ptr += sizeof(benfields)
+ cdef list row = PyList_New(nfields)
+
+ cdef int col
+ cdef int32_t belength
+ cdef Py_ssize_t length
+
+ for col in range(nfields):
+ memcpy(&belength, ptr, sizeof(belength))
+ ptr += sizeof(belength)
+ if belength == _binary_null:
+ field = None
+ else:
+ length = endian.be32toh(belength)
+ if ptr + length > bufend:
+ raise e.DataError("bad copy data: length exceeding data")
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(data, ptr, length))
+ ptr += length
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ return tx.load_sequence(row)
+
+
+def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]:
+ cdef unsigned char *fstart
+ cdef Py_ssize_t size
+ _buffer_as_string_and_size(data, <char **>&fstart, &size)
+
+ # politely assume that the number of fields will be the same as in the result
+ cdef int nfields = tx._nfields
+ cdef list row = PyList_New(nfields)
+
+ cdef unsigned char *fend
+ cdef unsigned char *rowend = fstart + size
+ cdef unsigned char *src
+ cdef unsigned char *tgt
+ cdef int col
+ cdef int num_bs
+
+ for col in range(nfields):
+ fend = fstart
+ num_bs = 0
+ # Scan to the end of the field, remember if you see any backslash
+ while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend:
+ if fend[0] == b'\\':
+ num_bs += 1
+ # skip the next char to avoid counting escaped backslashes twice
+ fend += 1
+ fend += 1
+
+ # Check if we stopped for the right reason
+ if fend >= rowend:
+ raise e.DataError("bad copy data: field delimiter not found")
+ elif fend[0] == b'\t' and col == nfields - 1:
+ raise e.DataError("bad copy data: got a tab at the end of the row")
+ elif fend[0] == b'\n' and col != nfields - 1:
+ raise e.DataError(
+ "bad copy format: got a newline before the end of the row")
+
+ # Is this a NULL?
+ if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N':
+ field = None
+
+ # Is this a field with no backslash?
+ elif num_bs == 0:
+ # Nothing to unescape: we don't need a copy
+ field = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(data, fstart, fend - fstart))
+
+ # This is a field containing backslashes
+ else:
+ # We need a copy of the buffer to unescape
+ field = PyByteArray_FromStringAndSize("", 0)
+ PyByteArray_Resize(field, fend - fstart - num_bs)
+ tgt = <unsigned char *>PyByteArray_AS_STRING(field)
+ src = fstart
+ while (src < fend):
+ if src[0] != b'\\':
+ tgt[0] = src[0]
+ else:
+ src += 1
+ tgt[0] = copy_unescape_lut[src[0]]
+ src += 1
+ tgt += 1
+
+ Py_INCREF(field)
+ PyList_SET_ITEM(row, col, field)
+
+ # Start of the field
+ fstart = fend + 1
+
+ # Convert the array of buffers into Python objects
+ return tx.load_sequence(row)
+
+
+cdef extern from *:
+ """
+/* handle chars to (un)escape in text copy representation */
+/* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */
+
+/* Escaping chars */
+static const char copy_escape_lut[] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 98, 116, 110, 118, 102, 114, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 92, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+/* Conversion of escaped to unescaped chars */
+static const char copy_unescape_lut[] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+ 96, 97, 8, 99, 100, 101, 12, 103, 104, 105, 106, 107, 108, 109, 10, 111,
+112, 113, 13, 115, 9, 117, 11, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
+144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
+160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
+176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
+192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
+208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
+240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
+};
+ """
+ const char[256] copy_escape_lut
+ const char[256] copy_unescape_lut
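The lookup tables above encode the characters that the COPY text format escapes; a
rough Python equivalent of the escaping step (only to illustrate the mapping, not the
code used by the module)::

    COPY_ESCAPES = {
        b"\b": b"\\b", b"\t": b"\\t", b"\n": b"\\n", b"\v": b"\\v",
        b"\f": b"\\f", b"\r": b"\\r", b"\\": b"\\\\",
    }

    def escape_copy_field(data: bytes) -> bytes:
        # replace each special char with its backslashed form, keep the rest
        return b"".join(COPY_ESCAPES.get(bytes([c]), bytes([c])) for c in data)

    print(escape_copy_field(b"a\tb\\c"))   # b'a\\tb\\\\c'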
diff --git a/psycopg_c/psycopg_c/_psycopg/endian.pxd b/psycopg_c/psycopg_c/_psycopg/endian.pxd
new file mode 100644
index 0000000..44e7305
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/endian.pxd
@@ -0,0 +1,155 @@
+"""
+Access to endian conversion functions
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.stdint cimport uint16_t, uint32_t, uint64_t
+
+cdef extern from * nogil:
+ # from https://gist.github.com/panzi/6856583
+ # Improved in:
+ # https://github.com/linux-sunxi/sunxi-tools/blob/master/include/portable_endian.h
+ """
+// "License": Public Domain
+// I, Mathias Panzenböck, place this file hereby into the public domain. Use it at your own risk for whatever you like.
+// In case there are jurisdictions that don't support putting things in the public domain you can also consider it to
+// be "dual licensed" under the BSD, MIT and Apache licenses, if you want to. This code is trivial anyway. Consider it
+// an example on how to get the endian conversion functions on different platforms.
+
+#ifndef PORTABLE_ENDIAN_H__
+#define PORTABLE_ENDIAN_H__
+
+#if (defined(_WIN16) || defined(_WIN32) || defined(_WIN64)) && !defined(__WINDOWS__)
+
+# define __WINDOWS__
+
+#endif
+
+#if defined(__linux__) || defined(__CYGWIN__)
+
+# include <endian.h>
+
+#elif defined(__APPLE__)
+
+# include <libkern/OSByteOrder.h>
+
+# define htobe16(x) OSSwapHostToBigInt16(x)
+# define htole16(x) OSSwapHostToLittleInt16(x)
+# define be16toh(x) OSSwapBigToHostInt16(x)
+# define le16toh(x) OSSwapLittleToHostInt16(x)
+
+# define htobe32(x) OSSwapHostToBigInt32(x)
+# define htole32(x) OSSwapHostToLittleInt32(x)
+# define be32toh(x) OSSwapBigToHostInt32(x)
+# define le32toh(x) OSSwapLittleToHostInt32(x)
+
+# define htobe64(x) OSSwapHostToBigInt64(x)
+# define htole64(x) OSSwapHostToLittleInt64(x)
+# define be64toh(x) OSSwapBigToHostInt64(x)
+# define le64toh(x) OSSwapLittleToHostInt64(x)
+
+# define __BYTE_ORDER BYTE_ORDER
+# define __BIG_ENDIAN BIG_ENDIAN
+# define __LITTLE_ENDIAN LITTLE_ENDIAN
+# define __PDP_ENDIAN PDP_ENDIAN
+
+#elif defined(__OpenBSD__) || defined(__NetBSD__) || defined(__FreeBSD__) || defined(__DragonFly__)
+
+# include <sys/endian.h>
+
+/* For functions still missing, try to substitute 'historic' OpenBSD names */
+#ifndef be16toh
+# define be16toh(x) betoh16(x)
+#endif
+#ifndef le16toh
+# define le16toh(x) letoh16(x)
+#endif
+#ifndef be32toh
+# define be32toh(x) betoh32(x)
+#endif
+#ifndef le32toh
+# define le32toh(x) letoh32(x)
+#endif
+#ifndef be64toh
+# define be64toh(x) betoh64(x)
+#endif
+#ifndef le64toh
+# define le64toh(x) letoh64(x)
+#endif
+
+#elif defined(__WINDOWS__)
+
+# include <winsock2.h>
+# ifndef _MSC_VER
+# include <sys/param.h>
+# endif
+
+# if BYTE_ORDER == LITTLE_ENDIAN
+
+# define htobe16(x) htons(x)
+# define htole16(x) (x)
+# define be16toh(x) ntohs(x)
+# define le16toh(x) (x)
+
+# define htobe32(x) htonl(x)
+# define htole32(x) (x)
+# define be32toh(x) ntohl(x)
+# define le32toh(x) (x)
+
+# define htobe64(x) htonll(x)
+# define htole64(x) (x)
+# define be64toh(x) ntohll(x)
+# define le64toh(x) (x)
+
+# elif BYTE_ORDER == BIG_ENDIAN
+
+ /* that would be xbox 360 */
+# define htobe16(x) (x)
+# define htole16(x) __builtin_bswap16(x)
+# define be16toh(x) (x)
+# define le16toh(x) __builtin_bswap16(x)
+
+# define htobe32(x) (x)
+# define htole32(x) __builtin_bswap32(x)
+# define be32toh(x) (x)
+# define le32toh(x) __builtin_bswap32(x)
+
+# define htobe64(x) (x)
+# define htole64(x) __builtin_bswap64(x)
+# define be64toh(x) (x)
+# define le64toh(x) __builtin_bswap64(x)
+
+# else
+
+# error byte order not supported
+
+# endif
+
+# define __BYTE_ORDER BYTE_ORDER
+# define __BIG_ENDIAN BIG_ENDIAN
+# define __LITTLE_ENDIAN LITTLE_ENDIAN
+# define __PDP_ENDIAN PDP_ENDIAN
+
+#else
+
+# error platform not supported
+
+#endif
+
+#endif
+ """
+ cdef uint16_t htobe16(uint16_t host_16bits)
+ cdef uint16_t htole16(uint16_t host_16bits)
+ cdef uint16_t be16toh(uint16_t big_endian_16bits)
+ cdef uint16_t le16toh(uint16_t little_endian_16bits)
+
+ cdef uint32_t htobe32(uint32_t host_32bits)
+ cdef uint32_t htole32(uint32_t host_32bits)
+ cdef uint32_t be32toh(uint32_t big_endian_32bits)
+ cdef uint32_t le32toh(uint32_t little_endian_32bits)
+
+ cdef uint64_t htobe64(uint64_t host_64bits)
+ cdef uint64_t htole64(uint64_t host_64bits)
+ cdef uint64_t be64toh(uint64_t big_endian_64bits)
+ cdef uint64_t le64toh(uint64_t little_endian_64bits)
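For reference, the same conversions expressed with Python's ``struct`` module (only
to illustrate the byte ordering; the C module never goes through Python for this)::

    import struct

    def htobe16(x: int) -> bytes:
        return struct.pack("!H", x)            # host to big-endian (network) order

    def be32toh(data: bytes) -> int:
        return struct.unpack("!I", data)[0]    # big-endian to host order

    print(htobe16(258).hex())                  # 0102
    print(be32toh(bytes.fromhex("00000102")))  # 258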
diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx
new file mode 100644
index 0000000..9ce9e54
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx
@@ -0,0 +1,276 @@
+"""
+C implementation of generators for the communication protocols with the libpq
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from cpython.object cimport PyObject_CallFunctionObjArgs
+
+from typing import List
+
+from psycopg import errors as e
+from psycopg.pq import abc, error_message
+from psycopg.abc import PipelineCommand, PQGen
+from psycopg._enums import Wait, Ready
+from psycopg._compat import Deque
+from psycopg._encodings import conninfo_encoding
+
+cdef object WAIT_W = Wait.W
+cdef object WAIT_R = Wait.R
+cdef object WAIT_RW = Wait.RW
+cdef object PY_READY_R = Ready.R
+cdef object PY_READY_W = Ready.W
+cdef object PY_READY_RW = Ready.RW
+cdef int READY_R = Ready.R
+cdef int READY_W = Ready.W
+cdef int READY_RW = Ready.RW
+
+def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
+ """
+ Generator to create a database connection without blocking.
+
+ """
+ cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode())
+ cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr
+ cdef int conn_status = libpq.PQstatus(pgconn_ptr)
+ cdef int poll_status
+
+ while True:
+ if conn_status == libpq.CONNECTION_BAD:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection is bad: {error_message(conn, encoding=encoding)}",
+ pgconn=conn
+ )
+
+ with nogil:
+ poll_status = libpq.PQconnectPoll(pgconn_ptr)
+
+ if poll_status == libpq.PGRES_POLLING_OK:
+ break
+ elif poll_status == libpq.PGRES_POLLING_READING:
+ yield (libpq.PQsocket(pgconn_ptr), WAIT_R)
+ elif poll_status == libpq.PGRES_POLLING_WRITING:
+ yield (libpq.PQsocket(pgconn_ptr), WAIT_W)
+ elif poll_status == libpq.PGRES_POLLING_FAILED:
+ encoding = conninfo_encoding(conninfo)
+ raise e.OperationalError(
+ f"connection failed: {error_message(conn, encoding=encoding)}",
+ pgconn=conn
+ )
+ else:
+ raise e.InternalError(
+ f"unexpected poll status: {poll_status}", pgconn=conn
+ )
+
+ conn.nonblocking = 1
+ return conn
+
+
+def execute(pq.PGconn pgconn) -> PQGen[List[abc.PGresult]]:
+ """
+ Generator sending a query and returning results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ yield from send(pgconn)
+ rv = yield from fetch_many(pgconn)
+ return rv
+
+
+def send(pq.PGconn pgconn) -> PQGen[None]:
+ """
+ Generator to send a query to the server without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query to the server using nonblocking functions.
+
+ After this generator has finished you may want to cycle using `fetch()`
+ to retrieve the results available.
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int status
+ cdef int cires
+
+ while True:
+ if pgconn.flush() == 0:
+ break
+
+ status = yield WAIT_RW
+ if status & READY_R:
+ with nogil:
+ # This call may read notifies which will be saved in the
+ # PGconn buffer and passed to Python later.
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+
+
+def fetch_many(pq.PGconn pgconn) -> PQGen[List[PGresult]]:
+ """
+ Generator retrieving results from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ cdef list results = []
+ cdef int status
+ cdef pq.PGresult result
+ cdef libpq.PGresult *pgres
+
+ while True:
+ result = yield from fetch(pgconn)
+ if result is None:
+ break
+ results.append(result)
+ pgres = result._pgresult_ptr
+
+ status = libpq.PQresultStatus(pgres)
+ if (
+ status == libpq.PGRES_COPY_IN
+ or status == libpq.PGRES_COPY_OUT
+ or status == libpq.PGRES_COPY_BOTH
+ ):
+ # After entering copy mode the libpq will create a phony result
+ # for every request so let's break the endless loop.
+ break
+
+ if status == libpq.PGRES_PIPELINE_SYNC:
+ # PIPELINE_SYNC is not followed by a NULL, but we return it alone
+ # similarly to other result sets.
+ break
+
+ return results
+
+
+def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
+ """
+ Generator retrieving a single result from the database without blocking.
+
+ The query must have already been sent to the server, so pgconn.flush() has
+ already returned 0.
+
+ Return a result from the database (whether success or error).
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int cires, ibres
+ cdef libpq.PGresult *pgres
+
+ with nogil:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+ if ibres:
+ yield WAIT_R
+ while True:
+ with nogil:
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if cires == 1:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+ if not ibres:
+ break
+ yield WAIT_R
+
+ _consume_notifies(pgconn)
+
+ with nogil:
+ pgres = libpq.PQgetResult(pgconn_ptr)
+ if pgres is NULL:
+ return None
+ return pq.PGresult._from_ptr(pgres)
+
+
+def pipeline_communicate(
+ pq.PGconn pgconn, commands: Deque[PipelineCommand]
+) -> PQGen[List[List[PGresult]]]:
+ """Generator to send queries from a connection in pipeline mode while also
+ receiving results.
+
+ Return a list of results, including single PIPELINE_SYNC elements.
+ """
+ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
+ cdef int cires
+ cdef int status
+ cdef int ready
+ cdef libpq.PGresult *pgres
+ cdef list res = []
+ cdef list results = []
+ cdef pq.PGresult r
+
+ while True:
+ ready = yield WAIT_RW
+
+ if ready & READY_R:
+ with nogil:
+ cires = libpq.PQconsumeInput(pgconn_ptr)
+ if 1 != cires:
+ raise e.OperationalError(
+ f"consuming input failed: {error_message(pgconn)}")
+
+ _consume_notifies(pgconn)
+
+ res: List[PGresult] = []
+ while True:
+ with nogil:
+ ibres = libpq.PQisBusy(pgconn_ptr)
+ if ibres:
+ break
+ pgres = libpq.PQgetResult(pgconn_ptr)
+
+ if pgres is NULL:
+ if not res:
+ break
+ results.append(res)
+ res = []
+ else:
+ status = libpq.PQresultStatus(pgres)
+ r = pq.PGresult._from_ptr(pgres)
+ if status == libpq.PGRES_PIPELINE_SYNC:
+ results.append([r])
+ break
+ else:
+ res.append(r)
+
+ if ready & READY_W:
+ pgconn.flush()
+ if not commands:
+ break
+ commands.popleft()()
+
+ return results
+
+
+cdef int _consume_notifies(pq.PGconn pgconn) except -1:
+ cdef object notify_handler = pgconn.notify_handler
+ cdef libpq.PGconn *pgconn_ptr
+ cdef libpq.PGnotify *notify
+
+ if notify_handler is not None:
+ while True:
+ pynotify = pgconn.notifies()
+ if pynotify is None:
+ break
+ PyObject_CallFunctionObjArgs(
+ notify_handler, <PyObject *>pynotify, NULL
+ )
+ else:
+ pgconn_ptr = pgconn._pgconn_ptr
+ while True:
+ notify = libpq.PQnotifies(pgconn_ptr)
+ if notify is NULL:
+ break
+ libpq.PQfreemem(notify)
+
+ return 0
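These generators yield ``Wait`` states and expect a ``Ready`` state back; a minimal
pure-Python driver for the ``PQGen`` style of generator (a sketch of the protocol,
not the loop psycopg actually uses) could look like::

    import select
    from psycopg._enums import Wait, Ready

    def drive(gen, fileno):
        """Run a PQGen generator to completion, waiting on fileno."""
        try:
            state = next(gen)
            while True:
                r, w, _ = select.select(
                    [fileno] if state in (Wait.R, Wait.RW) else [],
                    [fileno] if state in (Wait.W, Wait.RW) else [],
                    [],
                )
                ready = (Ready.R if r else 0) | (Ready.W if w else 0)
                state = gen.send(ready)
        except StopIteration as ex:
            return ex.value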
diff --git a/psycopg_c/psycopg_c/_psycopg/oids.pxd b/psycopg_c/psycopg_c/_psycopg/oids.pxd
new file mode 100644
index 0000000..2a864c4
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/oids.pxd
@@ -0,0 +1,92 @@
+"""
+Constants to refer to OIDS in C
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use tools/update_oids.py to update this data.
+
+cdef enum:
+ INVALID_OID = 0
+
+ # autogenerated: start
+
+ # Generated from PostgreSQL 15.0
+
+ ACLITEM_OID = 1033
+ BIT_OID = 1560
+ BOOL_OID = 16
+ BOX_OID = 603
+ BPCHAR_OID = 1042
+ BYTEA_OID = 17
+ CHAR_OID = 18
+ CID_OID = 29
+ CIDR_OID = 650
+ CIRCLE_OID = 718
+ DATE_OID = 1082
+ DATEMULTIRANGE_OID = 4535
+ DATERANGE_OID = 3912
+ FLOAT4_OID = 700
+ FLOAT8_OID = 701
+ GTSVECTOR_OID = 3642
+ INET_OID = 869
+ INT2_OID = 21
+ INT2VECTOR_OID = 22
+ INT4_OID = 23
+ INT4MULTIRANGE_OID = 4451
+ INT4RANGE_OID = 3904
+ INT8_OID = 20
+ INT8MULTIRANGE_OID = 4536
+ INT8RANGE_OID = 3926
+ INTERVAL_OID = 1186
+ JSON_OID = 114
+ JSONB_OID = 3802
+ JSONPATH_OID = 4072
+ LINE_OID = 628
+ LSEG_OID = 601
+ MACADDR_OID = 829
+ MACADDR8_OID = 774
+ MONEY_OID = 790
+ NAME_OID = 19
+ NUMERIC_OID = 1700
+ NUMMULTIRANGE_OID = 4532
+ NUMRANGE_OID = 3906
+ OID_OID = 26
+ OIDVECTOR_OID = 30
+ PATH_OID = 602
+ PG_LSN_OID = 3220
+ POINT_OID = 600
+ POLYGON_OID = 604
+ RECORD_OID = 2249
+ REFCURSOR_OID = 1790
+ REGCLASS_OID = 2205
+ REGCOLLATION_OID = 4191
+ REGCONFIG_OID = 3734
+ REGDICTIONARY_OID = 3769
+ REGNAMESPACE_OID = 4089
+ REGOPER_OID = 2203
+ REGOPERATOR_OID = 2204
+ REGPROC_OID = 24
+ REGPROCEDURE_OID = 2202
+ REGROLE_OID = 4096
+ REGTYPE_OID = 2206
+ TEXT_OID = 25
+ TID_OID = 27
+ TIME_OID = 1083
+ TIMESTAMP_OID = 1114
+ TIMESTAMPTZ_OID = 1184
+ TIMETZ_OID = 1266
+ TSMULTIRANGE_OID = 4533
+ TSQUERY_OID = 3615
+ TSRANGE_OID = 3908
+ TSTZMULTIRANGE_OID = 4534
+ TSTZRANGE_OID = 3910
+ TSVECTOR_OID = 3614
+ TXID_SNAPSHOT_OID = 2970
+ UUID_OID = 2950
+ VARBIT_OID = 1562
+ VARCHAR_OID = 1043
+ XID_OID = 28
+ XID8_OID = 5069
+ XML_OID = 142
+ # autogenerated: end
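An illustrative cross-check of a few of the constants above against ``pg_type``
(placeholder connection string; the file itself is regenerated with
``tools/update_oids.py`` as noted)::

    import psycopg

    with psycopg.connect("dbname=test") as conn:   # placeholder DSN
        rows = conn.execute(
            "select typname, oid from pg_type"
            " where typname in ('bool', 'text', 'numeric') order by typname"
        ).fetchall()
        print(rows)   # [('bool', 16), ('numeric', 1700), ('text', 25)]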
diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx
new file mode 100644
index 0000000..fc69725
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx
@@ -0,0 +1,640 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+
+Cython implementation: it can access lower level C features without creating
+too many temporary Python objects and with less memory copying.
+
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.ref cimport Py_INCREF, Py_DECREF
+from cpython.set cimport PySet_Add, PySet_Contains
+from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
+from cpython.list cimport (
+ PyList_New, PyList_CheckExact,
+ PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE)
+from cpython.bytes cimport PyBytes_AS_STRING
+from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
+from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
+
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from psycopg import errors as e
+from psycopg.pq import Format as PqFormat
+from psycopg.rows import Row, RowMaker
+from psycopg._encodings import pgconn_encoding
+
+NoneType = type(None)
+
+# internal structure: you are not supposed to know this. But it's worth some
+# 10% of the innermost loop, so I'm willing to ask for forgiveness later...
+
+ctypedef struct PGresAttValue:
+ int len
+ char *value
+
+ctypedef struct pg_result_int:
+ # NOTE: we are not supposed to know this structure's content
+ int ntups
+ int numAttributes
+ libpq.PGresAttDesc *attDescs
+ PGresAttValue **tuples
+ # ...more members, which we ignore
+
+
+@cython.freelist(16)
+cdef class RowLoader:
+ cdef CLoader cloader
+ cdef object pyloader
+ cdef object loadfunc
+
+
+@cython.freelist(16)
+cdef class RowDumper:
+ cdef CDumper cdumper
+ cdef object pydumper
+ cdef object dumpfunc
+ cdef object oid
+ cdef object format
+
+
+cdef class Transformer:
+ """
+ An object that can adapt efficiently between Python and PostgreSQL.
+
+ The life cycle of the object is the query, so it is assumed that attributes
+ such as the server version or the connection encoding will not change. The
+ object has its own state, so adapting several values of the same type can be
+ optimised.
+
+ """
+
+ cdef readonly object connection
+ cdef readonly object adapters
+ cdef readonly object types
+ cdef readonly object formats
+ cdef str _encoding
+ cdef int _none_oid
+
+ # mapping class -> Dumper instance (auto, text, binary)
+ cdef dict _auto_dumpers
+ cdef dict _text_dumpers
+ cdef dict _binary_dumpers
+
+ # mapping oid -> Loader instance (text, binary)
+ cdef dict _text_loaders
+ cdef dict _binary_loaders
+
+ # mapping oid -> Dumper instance (text, binary)
+ cdef dict _oid_text_dumpers
+ cdef dict _oid_binary_dumpers
+
+ cdef pq.PGresult _pgresult
+ cdef int _nfields, _ntuples
+ cdef list _row_dumpers
+ cdef list _row_loaders
+
+ cdef dict _oid_types
+
+ def __cinit__(self, context: Optional["AdaptContext"] = None):
+ if context is not None:
+ self.adapters = context.adapters
+ self.connection = context.connection
+ else:
+ from psycopg import postgres
+ self.adapters = postgres.adapters
+ self.connection = None
+
+ self.types = self.formats = None
+ self._none_oid = -1
+
+ @classmethod
+ def from_context(cls, context: Optional["AdaptContext"]):
+ """
+ Return a Transformer from an AdaptContext.
+
+ If the context is a Transformer instance, just return it.
+ """
+ return _tx_from_context(context)
+
+ @property
+ def encoding(self) -> str:
+ if not self._encoding:
+ conn = self.connection
+ self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+ return self._encoding
+
+ @property
+ def pgresult(self) -> Optional[PGresult]:
+ return self._pgresult
+
+ cpdef set_pgresult(
+ self,
+ pq.PGresult result,
+ object set_loaders = True,
+ object format = None
+ ):
+ self._pgresult = result
+
+ if result is None:
+ self._nfields = self._ntuples = 0
+ if set_loaders:
+ self._row_loaders = []
+ return
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ self._nfields = libpq.PQnfields(res)
+ self._ntuples = libpq.PQntuples(res)
+
+ if not set_loaders:
+ return
+
+ if not self._nfields:
+ self._row_loaders = []
+ return
+
+ if format is None:
+ format = libpq.PQfformat(res, 0)
+
+ cdef list loaders = PyList_New(self._nfields)
+ cdef PyObject *row_loader
+ cdef object oid
+
+ cdef int i
+ for i in range(self._nfields):
+ oid = libpq.PQftype(res, i)
+ row_loader = self._c_get_loader(<PyObject *>oid, <PyObject *>format)
+ Py_INCREF(<object>row_loader)
+ PyList_SET_ITEM(loaders, i, <object>row_loader)
+
+ self._row_loaders = loaders
+
+ def set_dumper_types(self, types: Sequence[int], format: Format) -> None:
+ cdef Py_ssize_t ntypes = len(types)
+ dumpers = PyList_New(ntypes)
+ cdef int i
+ for i in range(ntypes):
+ oid = types[i]
+ dumper_ptr = self.get_dumper_by_oid(
+ <PyObject *>oid, <PyObject *>format)
+ Py_INCREF(<object>dumper_ptr)
+ PyList_SET_ITEM(dumpers, i, <object>dumper_ptr)
+
+ self._row_dumpers = dumpers
+ self.types = tuple(types)
+ self.formats = [format] * ntypes
+
+ def set_loader_types(self, types: Sequence[int], format: Format) -> None:
+ self._c_loader_types(len(types), types, format)
+
+ cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, object format):
+ cdef list loaders = PyList_New(ntypes)
+
+ # these are used more as Python objects than as C ones
+ cdef PyObject *oid
+ cdef PyObject *row_loader
+ for i in range(ntypes):
+ oid = PyList_GET_ITEM(types, i)
+ row_loader = self._c_get_loader(oid, <PyObject *>format)
+ Py_INCREF(<object>row_loader)
+ PyList_SET_ITEM(loaders, i, <object>row_loader)
+
+ self._row_loaders = loaders
+
+ cpdef as_literal(self, obj):
+ cdef PyObject *row_dumper = self.get_row_dumper(
+ <PyObject *>obj, <PyObject *>PG_TEXT)
+
+ if (<RowDumper>row_dumper).cdumper is not None:
+ dumper = (<RowDumper>row_dumper).cdumper
+ else:
+ dumper = (<RowDumper>row_dumper).pydumper
+
+ rv = dumper.quote(obj)
+ oid = dumper.oid
+ # If the result is quoted and the oid is not unknown or text,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ if oid and oid != oids.TEXT_OID and rv and rv[-1] == 39:
+ if self._oid_types is None:
+ self._oid_types = {}
+ type_ptr = PyDict_GetItem(<object>self._oid_types, oid)
+ if type_ptr == NULL:
+ type_sql = b""
+ ti = self.adapters.types.get(oid)
+ if ti is not None:
+ if oid < 8192:
+ # builtin: prefer "timestamptz" to "timestamp with time zone"
+ type_sql = ti.name.encode(self.encoding)
+ else:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+
+ type_ptr = <PyObject *>type_sql
+ PyDict_SetItem(<object>self._oid_types, oid, type_sql)
+
+ if <object>type_ptr:
+ rv = b"%s::%s" % (rv, <object>type_ptr)
+
+ return rv
+
+ def get_dumper(self, obj, format) -> "Dumper":
+ cdef PyObject *row_dumper = self.get_row_dumper(
+ <PyObject *>obj, <PyObject *>format)
+ return (<RowDumper>row_dumper).pydumper
+
+ cdef PyObject *get_row_dumper(self, PyObject *obj, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowDumper for the given obj/fmt.
+ """
+ # Fast path: return a Dumper class already instantiated from the same type
+ cdef PyObject *cache
+ cdef PyObject *ptr
+ cdef PyObject *ptr1
+ cdef RowDumper row_dumper
+
+ # Normally, the type of the object dictates how to dump it
+ key = type(<object>obj)
+
+ # Establish where the dumper would be cached
+ bfmt = PyUnicode_AsUTF8String(<object>fmt)
+ cdef char cfmt = PyBytes_AS_STRING(bfmt)[0]
+ if cfmt == b's':
+ if self._auto_dumpers is None:
+ self._auto_dumpers = {}
+ cache = <PyObject *>self._auto_dumpers
+ elif cfmt == b'b':
+ if self._binary_dumpers is None:
+ self._binary_dumpers = {}
+ cache = <PyObject *>self._binary_dumpers
+ elif cfmt == b't':
+ if self._text_dumpers is None:
+ self._text_dumpers = {}
+ cache = <PyObject *>self._text_dumpers
+ else:
+ raise ValueError(
+ f"format should be a psycopg.adapt.Format, not {<object>fmt}")
+
+ # Reuse an existing Dumper class for objects of the same type
+ ptr = PyDict_GetItem(<object>cache, key)
+ if ptr == NULL:
+ dcls = PyObject_CallFunctionObjArgs(
+ self.adapters.get_dumper, <PyObject *>key, fmt, NULL)
+ dumper = PyObject_CallFunctionObjArgs(
+ dcls, <PyObject *>key, <PyObject *>self, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, key, row_dumper)
+ ptr = <PyObject *>row_dumper
+
+ # Check if the dumper requires an upgrade to handle this specific value
+ if (<RowDumper>ptr).cdumper is not None:
+ key1 = (<RowDumper>ptr).cdumper.get_key(<object>obj, <object>fmt)
+ else:
+ key1 = PyObject_CallFunctionObjArgs(
+ (<RowDumper>ptr).pydumper.get_key, obj, fmt, NULL)
+ if key1 is key:
+ return ptr
+
+ # If it does, ask the dumper to create its own upgraded version
+ ptr1 = PyDict_GetItem(<object>cache, key1)
+ if ptr1 != NULL:
+ return ptr1
+
+ if (<RowDumper>ptr).cdumper is not None:
+ dumper = (<RowDumper>ptr).cdumper.upgrade(<object>obj, <object>fmt)
+ else:
+ dumper = PyObject_CallFunctionObjArgs(
+ (<RowDumper>ptr).pydumper.upgrade, obj, fmt, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, key1, row_dumper)
+ return <PyObject *>row_dumper
+
+ cdef PyObject *get_dumper_by_oid(self, PyObject *oid, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowDumper for the given oid/fmt.
+ """
+ cdef PyObject *ptr
+ cdef PyObject *cache
+ cdef RowDumper row_dumper
+
+ # Establish where the dumper would be cached
+ cdef int cfmt = <object>fmt
+ if cfmt == 0:
+ if self._oid_text_dumpers is None:
+ self._oid_text_dumpers = {}
+ cache = <PyObject *>self._oid_text_dumpers
+ elif cfmt == 1:
+ if self._oid_binary_dumpers is None:
+ self._oid_binary_dumpers = {}
+ cache = <PyObject *>self._oid_binary_dumpers
+ else:
+ raise ValueError(
+ f"format should be a psycopg.pq.Format, not {<object>fmt}")
+
+ # Reuse an existing Dumper class for the same oid
+ ptr = PyDict_GetItem(<object>cache, <object>oid)
+ if ptr == NULL:
+ dcls = PyObject_CallFunctionObjArgs(
+ self.adapters.get_dumper_by_oid, oid, fmt, NULL)
+ dumper = PyObject_CallFunctionObjArgs(
+ dcls, <PyObject *>NoneType, <PyObject *>self, NULL)
+
+ row_dumper = _as_row_dumper(dumper)
+ PyDict_SetItem(<object>cache, <object>oid, row_dumper)
+ ptr = <PyObject *>row_dumper
+
+ return ptr
+
+ cpdef dump_sequence(self, object params, object formats):
+ # Verify that they are not None and that PyList_GET_ITEM won't blow up
+ cdef Py_ssize_t nparams = len(params)
+ cdef list out = PyList_New(nparams)
+
+ cdef int i
+ cdef PyObject *dumper_ptr # borrowed pointer to row dumper
+ cdef object dumped
+ cdef Py_ssize_t size
+
+ if self._none_oid < 0:
+ self._none_oid = self.adapters.get_dumper(NoneType, "s").oid
+
+ dumpers = self._row_dumpers
+
+ if dumpers:
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ dumper_ptr = PyList_GET_ITEM(dumpers, i)
+ if (<RowDumper>dumper_ptr).cdumper is not None:
+ dumped = PyByteArray_FromStringAndSize("", 0)
+ size = (<RowDumper>dumper_ptr).cdumper.cdump(
+ param, <bytearray>dumped, 0)
+ PyByteArray_Resize(dumped, size)
+ else:
+ dumped = PyObject_CallFunctionObjArgs(
+ (<RowDumper>dumper_ptr).dumpfunc,
+ <PyObject *>param, NULL)
+ else:
+ dumped = None
+
+ Py_INCREF(dumped)
+ PyList_SET_ITEM(out, i, dumped)
+
+ return out
+
+ cdef tuple types = PyTuple_New(nparams)
+ cdef list pqformats = PyList_New(nparams)
+
+ for i in range(nparams):
+ param = params[i]
+ if param is not None:
+ dumper_ptr = self.get_row_dumper(
+ <PyObject *>param, <PyObject *>formats[i])
+ if (<RowDumper>dumper_ptr).cdumper is not None:
+ dumped = PyByteArray_FromStringAndSize("", 0)
+ size = (<RowDumper>dumper_ptr).cdumper.cdump(
+ param, <bytearray>dumped, 0)
+ PyByteArray_Resize(dumped, size)
+ else:
+ dumped = PyObject_CallFunctionObjArgs(
+ (<RowDumper>dumper_ptr).dumpfunc,
+ <PyObject *>param, NULL)
+ oid = (<RowDumper>dumper_ptr).oid
+ fmt = (<RowDumper>dumper_ptr).format
+ else:
+ dumped = None
+ oid = self._none_oid
+ fmt = PQ_TEXT
+
+ Py_INCREF(dumped)
+ PyList_SET_ITEM(out, i, dumped)
+
+ Py_INCREF(oid)
+ PyTuple_SET_ITEM(types, i, oid)
+
+ Py_INCREF(fmt)
+ PyList_SET_ITEM(pqformats, i, fmt)
+
+ self.types = types
+ self.formats = pqformats
+ return out
+
+ def load_rows(self, int row0, int row1, object make_row) -> List[Row]:
+ if self._pgresult is None:
+ raise e.InterfaceError("result not set")
+
+ if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
+ raise e.InterfaceError(
+ f"rows must be included between 0 and {self._ntuples}"
+ )
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ # cheeky access to the internal PGresult structure
+ cdef pg_result_int *ires = <pg_result_int*>res
+
+ cdef int row
+ cdef int col
+ cdef PGresAttValue *attval
+ cdef object record # not 'tuple' as it would check on assignment
+
+ cdef object records = PyList_New(row1 - row0)
+ for row in range(row0, row1):
+ record = PyTuple_New(self._nfields)
+ Py_INCREF(record)
+ PyList_SET_ITEM(records, row - row0, record)
+
+ cdef PyObject *loader # borrowed RowLoader
+ cdef PyObject *brecord # borrowed
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+
+ for col in range(self._nfields):
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ for row in range(row0, row1):
+ brecord = PyList_GET_ITEM(records, row - row0)
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ pyval = (<RowLoader>loader).cloader.cload(
+ attval.value, attval.len)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(<object>brecord, col, pyval)
+
+ else:
+ for row in range(row0, row1):
+ brecord = PyList_GET_ITEM(records, row - row0)
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(<object>brecord, col, pyval)
+
+ if make_row is not tuple:
+ for i in range(row1 - row0):
+ brecord = PyList_GET_ITEM(records, i)
+ record = PyObject_CallFunctionObjArgs(
+ make_row, <PyObject *>brecord, NULL)
+ Py_INCREF(record)
+ PyList_SET_ITEM(records, i, record)
+ Py_DECREF(<object>brecord)
+ return records
+
+ def load_row(self, int row, object make_row) -> Optional[Row]:
+ if self._pgresult is None:
+ return None
+
+ if not 0 <= row < self._ntuples:
+ return None
+
+ cdef libpq.PGresult *res = self._pgresult._pgresult_ptr
+ # cheeky access to the internal PGresult structure
+ cdef pg_result_int *ires = <pg_result_int*>res
+
+ cdef PyObject *loader # borrowed RowLoader
+ cdef int col
+ cdef PGresAttValue *attval
+ cdef object record # not 'tuple' as it would check on assignment
+
+ record = PyTuple_New(self._nfields)
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+
+ for col in range(self._nfields):
+ attval = &(ires.tuples[row][col])
+ if attval.len == -1: # NULL_LEN
+ pyval = None
+ else:
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ pyval = (<RowLoader>loader).cloader.cload(
+ attval.value, attval.len)
+ else:
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(record, col, pyval)
+
+ if make_row is not tuple:
+ record = PyObject_CallFunctionObjArgs(
+ make_row, <PyObject *>record, NULL)
+ return record
+
+ cpdef object load_sequence(self, record: Sequence[Optional[Buffer]]):
+ cdef Py_ssize_t nfields = len(record)
+ out = PyTuple_New(nfields)
+ cdef PyObject *loader # borrowed RowLoader
+ cdef int col
+ cdef char *ptr
+ cdef Py_ssize_t size
+
+ row_loaders = self._row_loaders # avoid an incref/decref per item
+ if PyList_GET_SIZE(row_loaders) != nfields:
+ raise e.ProgrammingError(
+ f"cannot load sequence of {nfields} items:"
+ f" {len(self._row_loaders)} loaders registered")
+
+ for col in range(nfields):
+ item = record[col]
+ if item is None:
+ Py_INCREF(None)
+ PyTuple_SET_ITEM(out, col, None)
+ continue
+
+ loader = PyList_GET_ITEM(row_loaders, col)
+ if (<RowLoader>loader).cloader is not None:
+ _buffer_as_string_and_size(item, &ptr, &size)
+ pyval = (<RowLoader>loader).cloader.cload(ptr, size)
+ else:
+ pyval = PyObject_CallFunctionObjArgs(
+ (<RowLoader>loader).loadfunc, <PyObject *>item, NULL)
+
+ Py_INCREF(pyval)
+ PyTuple_SET_ITEM(out, col, pyval)
+
+ return out
+
+ def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ cdef PyObject *row_loader = self._c_get_loader(
+ <PyObject *>oid, <PyObject *>format)
+ return (<RowLoader>row_loader).pyloader
+
+ cdef PyObject *_c_get_loader(self, PyObject *oid, PyObject *fmt) except NULL:
+ """
+ Return a borrowed reference to the RowLoader instance for given oid/fmt
+ """
+ cdef PyObject *ptr
+ cdef PyObject *cache
+
+ if <object>fmt == PQ_TEXT:
+ if self._text_loaders is None:
+ self._text_loaders = {}
+ cache = <PyObject *>self._text_loaders
+ elif <object>fmt == PQ_BINARY:
+ if self._binary_loaders is None:
+ self._binary_loaders = {}
+ cache = <PyObject *>self._binary_loaders
+ else:
+ raise ValueError(
+ f"format should be a psycopg.pq.Format, not {format}")
+
+ ptr = PyDict_GetItem(<object>cache, <object>oid)
+ if ptr != NULL:
+ return ptr
+
+ loader_cls = self.adapters.get_loader(<object>oid, <object>fmt)
+ if loader_cls is None:
+ loader_cls = self.adapters.get_loader(oids.INVALID_OID, <object>fmt)
+ if loader_cls is None:
+ raise e.InterfaceError("unknown oid loader not found")
+
+ loader = PyObject_CallFunctionObjArgs(
+ loader_cls, oid, <PyObject *>self, NULL)
+
+ cdef RowLoader row_loader = RowLoader()
+ row_loader.pyloader = loader
+ row_loader.loadfunc = loader.load
+ if isinstance(loader, CLoader):
+ row_loader.cloader = <CLoader>loader
+
+ PyDict_SetItem(<object>cache, <object>oid, row_loader)
+ return <PyObject *>row_loader
+
+
+cdef object _as_row_dumper(object dumper):
+ cdef RowDumper row_dumper = RowDumper()
+
+ row_dumper.pydumper = dumper
+ row_dumper.dumpfunc = dumper.dump
+ row_dumper.oid = dumper.oid
+ row_dumper.format = dumper.format
+
+ if isinstance(dumper, CDumper):
+ row_dumper.cdumper = <CDumper>dumper
+
+ return row_dumper
+
+
+cdef Transformer _tx_from_context(object context):
+ if isinstance(context, Transformer):
+ return context
+ else:
+ return Transformer(context)
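The caching strategy in ``get_row_dumper()`` above boils down to a two-level lookup:
dumpers are cached by the value's type, and a dumper may refine the key for a
specific value and provide an upgraded instance. A simplified pure-Python sketch of
the same idea (names are illustrative)::

    class DumperCache:
        def __init__(self, factories):
            self.factories = factories   # maps type -> dumper factory
            self.cache = {}              # maps key -> dumper instance

        def get(self, obj):
            key = type(obj)
            dumper = self.cache.get(key)
            if dumper is None:
                dumper = self.cache[key] = self.factories[key]()
            key1 = dumper.get_key(obj)   # may depend on the value itself
            if key1 == key:
                return dumper
            upgraded = self.cache.get(key1)
            if upgraded is None:
                upgraded = self.cache[key1] = dumper.upgrade(obj)
            return upgraded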
diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx
new file mode 100644
index 0000000..0af6c57
--- /dev/null
+++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx
@@ -0,0 +1,197 @@
+"""
+C implementation of waiting functions
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from cpython.object cimport PyObject_CallFunctionObjArgs
+
+cdef extern from *:
+ """
+#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
+
+#if defined(HAVE_POLL_H)
+#include <poll.h>
+#elif defined(HAVE_SYS_POLL_H)
+#include <sys/poll.h>
+#endif
+
+#else /* no poll available */
+
+#ifdef MS_WINDOWS
+#include <winsock2.h>
+#else
+#include <sys/select.h>
+#endif
+
+#endif /* HAVE_POLL */
+
+#define SELECT_EV_READ 1
+#define SELECT_EV_WRITE 2
+
+#define SEC_TO_MS 1000
+#define SEC_TO_US (1000 * 1000)
+
+/* Use select to wait for readiness on fileno.
+ *
+ * - Return SELECT_EV_* if the file is ready
+ * - Return 0 on timeout
+ * - Return -1 (and set an exception) on error.
+ *
+ * The wisdom of this function comes from:
+ *
+ * - PostgreSQL libpq (see src/interfaces/libpq/fe-misc.c)
+ * - Python select module (see Modules/selectmodule.c)
+ */
+static int
+wait_c_impl(int fileno, int wait, float timeout)
+{
+ int select_rv;
+ int rv = 0;
+
+#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
+
+ struct pollfd input_fd;
+ int timeout_ms;
+
+ input_fd.fd = fileno;
+ input_fd.events = POLLERR;
+ input_fd.revents = 0;
+
+ if (wait & SELECT_EV_READ) { input_fd.events |= POLLIN; }
+ if (wait & SELECT_EV_WRITE) { input_fd.events |= POLLOUT; }
+
+ if (timeout < 0.0) {
+ timeout_ms = -1;
+ } else {
+ timeout_ms = (int)(timeout * SEC_TO_MS);
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ errno = 0;
+ select_rv = poll(&input_fd, 1, timeout_ms);
+ Py_END_ALLOW_THREADS
+
+ if (PyErr_CheckSignals()) { goto finally; }
+
+ if (select_rv < 0) {
+ goto error;
+ }
+
+ if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
+ if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+
+#else
+
+ fd_set ifds;
+ fd_set ofds;
+ fd_set efds;
+ struct timeval tv, *tvptr;
+
+#ifndef MS_WINDOWS
+ if (fileno >= 1024) {
+ PyErr_SetString(
+ PyExc_ValueError, /* same exception of Python's 'select.select()' */
+ "connection file descriptor out of range for 'select()'");
+ return -1;
+ }
+#endif
+
+ FD_ZERO(&ifds);
+ FD_ZERO(&ofds);
+ FD_ZERO(&efds);
+
+ if (wait & SELECT_EV_READ) { FD_SET(fileno, &ifds); }
+ if (wait & SELECT_EV_WRITE) { FD_SET(fileno, &ofds); }
+ FD_SET(fileno, &efds);
+
+ /* Compute appropriate timeout interval */
+ if (timeout < 0.0) {
+ tvptr = NULL;
+ }
+ else {
+ tv.tv_sec = (int)timeout;
+ tv.tv_usec = (int)(((long)(timeout * SEC_TO_US)) % SEC_TO_US);
+ tvptr = &tv;
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ errno = 0;
+ select_rv = select(fileno + 1, &ifds, &ofds, &efds, tvptr);
+ Py_END_ALLOW_THREADS
+
+ if (PyErr_CheckSignals()) { goto finally; }
+
+ if (select_rv < 0) {
+ goto error;
+ }
+
+ if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+ if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+
+#endif /* HAVE_POLL */
+
+ return rv;
+
+error:
+
+#ifdef MS_WINDOWS
+ if (select_rv == SOCKET_ERROR) {
+ PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
+ }
+#else
+ if (select_rv < 0) {
+ PyErr_SetFromErrno(PyExc_OSError);
+ }
+#endif
+ else {
+ PyErr_SetString(PyExc_OSError, "unexpected error from select()");
+ }
+
+finally:
+
+ return -1;
+
+}
+ """
+ cdef int wait_c_impl(int fileno, int wait, float timeout) except -1
+
+
+def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV:
+ """
+ Wait for a generator using poll or select.
+ """
+ cdef float ctimeout
+ cdef int wait, ready
+ cdef PyObject *pyready
+
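+    # A negative timeout means "wait forever": wait_c_impl passes -1 to poll()
+    # or a NULL timeval to select().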
+ if timeout is None:
+ ctimeout = -1.0
+ else:
+ ctimeout = float(timeout)
+ if ctimeout < 0.0:
+ ctimeout = -1.0
+
+ send = gen.send
+
+ try:
+ wait = next(gen)
+
+ while True:
+ ready = wait_c_impl(fileno, wait, ctimeout)
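+            # 0 means the timeout expired with the socket not ready: poll again.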
+ if ready == 0:
+ continue
+ elif ready == READY_R:
+ pyready = <PyObject *>PY_READY_R
+ elif ready == READY_RW:
+ pyready = <PyObject *>PY_READY_RW
+ elif ready == READY_W:
+ pyready = <PyObject *>PY_READY_W
+ else:
+ raise AssertionError(f"unexpected ready value: {ready}")
+
+ wait = PyObject_CallFunctionObjArgs(send, pyready, NULL)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd
new file mode 100644
index 0000000..57825dd
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq.pxd
@@ -0,0 +1,78 @@
+# pid_t is needed below, but Windows doesn't define it.
+# Don't use Cython "IF" so that the generated C is portable and can be
+# included in the sdist.
+cdef extern from * nogil:
+ """
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ typedef signed pid_t;
+#else
+ #include <fcntl.h>
+#endif
+ """
+ ctypedef signed pid_t
+
+from psycopg_c.pq cimport libpq
+
+ctypedef char *(*conn_bytes_f) (const libpq.PGconn *)
+ctypedef int(*conn_int_f) (const libpq.PGconn *)
+
+
+cdef class PGconn:
+ cdef libpq.PGconn* _pgconn_ptr
+ cdef object __weakref__
+ cdef public object notice_handler
+ cdef public object notify_handler
+ cdef pid_t _procpid
+
+ @staticmethod
+ cdef PGconn _from_ptr(libpq.PGconn *ptr)
+
+ cpdef int flush(self) except -1
+ cpdef object notifies(self)
+
+
+cdef class PGresult:
+ cdef libpq.PGresult* _pgresult_ptr
+
+ @staticmethod
+ cdef PGresult _from_ptr(libpq.PGresult *ptr)
+
+
+cdef class PGcancel:
+ cdef libpq.PGcancel* pgcancel_ptr
+
+ @staticmethod
+ cdef PGcancel _from_ptr(libpq.PGcancel *ptr)
+
+
+cdef class Escaping:
+ cdef PGconn conn
+
+ cpdef escape_literal(self, data)
+ cpdef escape_identifier(self, data)
+ cpdef escape_string(self, data)
+ cpdef escape_bytea(self, data)
+ cpdef unescape_bytea(self, const unsigned char *data)
+
+
+cdef class PQBuffer:
+ cdef unsigned char *buf
+ cdef Py_ssize_t len
+
+ @staticmethod
+ cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length)
+
+
+cdef class ViewBuffer:
+ cdef unsigned char *buf
+ cdef Py_ssize_t len
+ cdef object obj
+
+ @staticmethod
+ cdef ViewBuffer _from_buffer(
+ object obj, unsigned char *buf, Py_ssize_t length)
+
+
+cdef int _buffer_as_string_and_size(
+ data: "Buffer", char **ptr, Py_ssize_t *length
+) except -1
diff --git a/psycopg_c/psycopg_c/pq.pyx b/psycopg_c/psycopg_c/pq.pyx
new file mode 100644
index 0000000..d397c17
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq.pyx
@@ -0,0 +1,38 @@
+"""
+libpq Python wrapper using cython bindings.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c.pq cimport libpq
+
+import logging
+
+from psycopg import errors as e
+from psycopg.pq import Format
+from psycopg.pq.misc import error_message
+
+logger = logging.getLogger("psycopg")
+
+__impl__ = 'c'
+__build_version__ = libpq.PG_VERSION_NUM
+
+
+def version():
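+    # Return the libpq version as an integer, e.g. 140005 for version 14.5.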
+ return libpq.PQlibVersion()
+
+
+include "pq/pgconn.pyx"
+include "pq/pgresult.pyx"
+include "pq/pgcancel.pyx"
+include "pq/conninfo.pyx"
+include "pq/escaping.pyx"
+include "pq/pqbuffer.pyx"
+
+
+# importing the ssl module sets up Python's libcrypto callbacks
+import ssl # noqa
+
+# disable libcrypto setup in libpq, so it won't stomp on the callbacks
+# that have already been set up
+libpq.PQinitOpenSSL(1, 0)
diff --git a/psycopg_c/psycopg_c/pq/__init__.pxd b/psycopg_c/psycopg_c/pq/__init__.pxd
new file mode 100644
index 0000000..ce8c60c
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/__init__.pxd
@@ -0,0 +1,9 @@
+"""
+psycopg_c.pq cython module.
+
+This file is necessary to allow c-importing pxd files from this directory.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg_c.pq cimport libpq
diff --git a/psycopg_c/psycopg_c/pq/conninfo.pyx b/psycopg_c/psycopg_c/pq/conninfo.pyx
new file mode 100644
index 0000000..3443de1
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/conninfo.pyx
@@ -0,0 +1,61 @@
+"""
+psycopg_c.pq.Conninfo object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from psycopg.pq.misc import ConninfoOption
+
+
+class Conninfo:
+ @classmethod
+ def get_defaults(cls) -> List[ConninfoOption]:
+ cdef libpq.PQconninfoOption *opts = libpq.PQconndefaults()
+        if opts is NULL:
+ raise MemoryError("couldn't allocate connection defaults")
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ @classmethod
+ def parse(cls, const char *conninfo) -> List[ConninfoOption]:
+ cdef char *errmsg = NULL
+ cdef libpq.PQconninfoOption *opts = libpq.PQconninfoParse(conninfo, &errmsg)
+ if opts is NULL:
+ if errmsg is NULL:
+ raise MemoryError("couldn't allocate on conninfo parse")
+ else:
+ exc = e.OperationalError(errmsg.decode("utf8", "replace"))
+ libpq.PQfreemem(errmsg)
+ raise exc
+
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ def __repr__(self):
+ return f"<{type(self).__name__} ({self.keyword.decode('ascii')})>"
+
+
+cdef _options_from_array(libpq.PQconninfoOption *opts):
+ rv = []
+ cdef int i = 0
+ cdef libpq.PQconninfoOption* opt
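+    # The array returned by the libpq is terminated by an entry with a NULL keyword.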
+ while True:
+ opt = opts + i
+ if opt.keyword is NULL:
+ break
+ rv.append(
+ ConninfoOption(
+ keyword=opt.keyword,
+ envvar=opt.envvar if opt.envvar is not NULL else None,
+ compiled=opt.compiled if opt.compiled is not NULL else None,
+ val=opt.val if opt.val is not NULL else None,
+ label=opt.label if opt.label is not NULL else None,
+ dispchar=opt.dispchar if opt.dispchar is not NULL else None,
+ dispsize=opt.dispsize,
+ )
+ )
+ i += 1
+
+ return rv
diff --git a/psycopg_c/psycopg_c/pq/escaping.pyx b/psycopg_c/psycopg_c/pq/escaping.pyx
new file mode 100644
index 0000000..f0a44d3
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/escaping.pyx
@@ -0,0 +1,132 @@
+"""
+psycopg_c.pq.Escaping object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from libc.string cimport strlen
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+
+
+cdef class Escaping:
+ def __init__(self, PGconn conn = None):
+ self.conn = conn
+
+ cpdef escape_literal(self, data):
+ cdef char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ if self.conn is None:
+ raise e.OperationalError("escape_literal failed: no connection provided")
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ out = libpq.PQescapeLiteral(self.conn._pgconn_ptr, ptr, length)
+ if out is NULL:
+ raise e.OperationalError(
+ f"escape_literal failed: {error_message(self.conn)}"
+ )
+
+ rv = out[:strlen(out)]
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef escape_identifier(self, data):
+ cdef char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is None:
+ raise e.OperationalError("escape_identifier failed: no connection provided")
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ out = libpq.PQescapeIdentifier(self.conn._pgconn_ptr, ptr, length)
+ if out is NULL:
+ raise e.OperationalError(
+ f"escape_identifier failed: {error_message(self.conn)}"
+ )
+
+ rv = out[:strlen(out)]
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef escape_string(self, data):
+ cdef int error
+ cdef size_t len_out
+ cdef char *ptr
+ cdef char *buf_out
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is not None:
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
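+            # Per the libpq docs the output buffer must hold at least 2 * length + 1 bytes.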
+            buf_out = <char *>PyMem_Malloc(length * 2 + 1)
+            len_out = libpq.PQescapeStringConn(
+                self.conn._pgconn_ptr, buf_out, ptr, length, &error
+            )
+            if error:
+                PyMem_Free(buf_out)
+                raise e.OperationalError(
+                    f"escape_string failed: {error_message(self.conn)}"
+                )
+
+ else:
+ buf_out = <char *>PyMem_Malloc(length * 2 + 1)
+ len_out = libpq.PQescapeString(buf_out, ptr, length)
+
+ rv = buf_out[:len_out]
+ PyMem_Free(buf_out)
+ return rv
+
+ cpdef escape_bytea(self, data):
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ if self.conn is not None and self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ _buffer_as_string_and_size(data, &ptr, &length)
+
+ if self.conn is not None:
+ out = libpq.PQescapeByteaConn(
+ self.conn._pgconn_ptr, <unsigned char *>ptr, length, &len_out)
+ else:
+ out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out)
+
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out - 1] # out includes final 0
+ libpq.PQfreemem(out)
+ return rv
+
+ cpdef unescape_bytea(self, const unsigned char *data):
+ # not needed, but let's keep it symmetric with the escaping:
+ # if a connection is passed in, it must be valid.
+ if self.conn is not None:
+ if self.conn._pgconn_ptr is NULL:
+ raise e.OperationalError("the connection is closed")
+
+ cdef size_t len_out
+ cdef unsigned char *out = libpq.PQunescapeBytea(data, &len_out)
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out]
+ libpq.PQfreemem(out)
+ return rv
diff --git a/psycopg_c/psycopg_c/pq/libpq.pxd b/psycopg_c/psycopg_c/pq/libpq.pxd
new file mode 100644
index 0000000..5e05e40
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/libpq.pxd
@@ -0,0 +1,321 @@
+"""
+Libpq header definition for the cython psycopg.pq implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cdef extern from "stdio.h":
+
+ ctypedef struct FILE:
+ pass
+
+cdef extern from "pg_config.h":
+
+ int PG_VERSION_NUM
+
+
+cdef extern from "libpq-fe.h":
+
+ # structures and types
+
+ ctypedef unsigned int Oid
+
+ ctypedef struct PGconn:
+ pass
+
+ ctypedef struct PGresult:
+ pass
+
+ ctypedef struct PQconninfoOption:
+ char *keyword
+ char *envvar
+ char *compiled
+ char *val
+ char *label
+ char *dispchar
+ int dispsize
+
+ ctypedef struct PGnotify:
+ char *relname
+ int be_pid
+ char *extra
+
+ ctypedef struct PGcancel:
+ pass
+
+ ctypedef struct PGresAttDesc:
+ char *name
+ Oid tableid
+ int columnid
+ int format
+ Oid typid
+ int typlen
+ int atttypmod
+
+ # enums
+
+ ctypedef enum PostgresPollingStatusType:
+ PGRES_POLLING_FAILED = 0
+ PGRES_POLLING_READING
+ PGRES_POLLING_WRITING
+ PGRES_POLLING_OK
+ PGRES_POLLING_ACTIVE
+
+
+ ctypedef enum PGPing:
+ PQPING_OK
+ PQPING_REJECT
+ PQPING_NO_RESPONSE
+ PQPING_NO_ATTEMPT
+
+ ctypedef enum ConnStatusType:
+ CONNECTION_OK
+ CONNECTION_BAD
+ CONNECTION_STARTED
+ CONNECTION_MADE
+ CONNECTION_AWAITING_RESPONSE
+ CONNECTION_AUTH_OK
+ CONNECTION_SETENV
+ CONNECTION_SSL_STARTUP
+ CONNECTION_NEEDED
+ CONNECTION_CHECK_WRITABLE
+ CONNECTION_GSS_STARTUP
+ # CONNECTION_CHECK_TARGET PG 12
+
+ ctypedef enum PGTransactionStatusType:
+ PQTRANS_IDLE
+ PQTRANS_ACTIVE
+ PQTRANS_INTRANS
+ PQTRANS_INERROR
+ PQTRANS_UNKNOWN
+
+ ctypedef enum ExecStatusType:
+ PGRES_EMPTY_QUERY = 0
+ PGRES_COMMAND_OK
+ PGRES_TUPLES_OK
+ PGRES_COPY_OUT
+ PGRES_COPY_IN
+ PGRES_BAD_RESPONSE
+ PGRES_NONFATAL_ERROR
+ PGRES_FATAL_ERROR
+ PGRES_COPY_BOTH
+ PGRES_SINGLE_TUPLE
+ PGRES_PIPELINE_SYNC
+        PGRES_PIPELINE_ABORTED
+
+ # 33.1. Database Connection Control Functions
+ PGconn *PQconnectdb(const char *conninfo)
+ PGconn *PQconnectStart(const char *conninfo)
+ PostgresPollingStatusType PQconnectPoll(PGconn *conn) nogil
+ PQconninfoOption *PQconndefaults()
+ PQconninfoOption *PQconninfo(PGconn *conn)
+ PQconninfoOption *PQconninfoParse(const char *conninfo, char **errmsg)
+ void PQfinish(PGconn *conn)
+ void PQreset(PGconn *conn)
+ int PQresetStart(PGconn *conn)
+ PostgresPollingStatusType PQresetPoll(PGconn *conn)
+ PGPing PQping(const char *conninfo)
+
+ # 33.2. Connection Status Functions
+ char *PQdb(const PGconn *conn)
+ char *PQuser(const PGconn *conn)
+ char *PQpass(const PGconn *conn)
+ char *PQhost(const PGconn *conn)
+ char *PQhostaddr(const PGconn *conn)
+ char *PQport(const PGconn *conn)
+ char *PQtty(const PGconn *conn)
+ char *PQoptions(const PGconn *conn)
+ ConnStatusType PQstatus(const PGconn *conn)
+ PGTransactionStatusType PQtransactionStatus(const PGconn *conn)
+ const char *PQparameterStatus(const PGconn *conn, const char *paramName)
+ int PQprotocolVersion(const PGconn *conn)
+ int PQserverVersion(const PGconn *conn)
+ char *PQerrorMessage(const PGconn *conn)
+ int PQsocket(const PGconn *conn) nogil
+ int PQbackendPID(const PGconn *conn)
+ int PQconnectionNeedsPassword(const PGconn *conn)
+ int PQconnectionUsedPassword(const PGconn *conn)
+ int PQsslInUse(PGconn *conn) # TODO: const in PG 12 docs - verify/report
+ # TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
+
+ # 33.3. Command Execution Functions
+ PGresult *PQexec(PGconn *conn, const char *command) nogil
+ PGresult *PQexecParams(PGconn *conn,
+ const char *command,
+ int nParams,
+ const Oid *paramTypes,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ PGresult *PQprepare(PGconn *conn,
+ const char *stmtName,
+ const char *query,
+ int nParams,
+ const Oid *paramTypes) nogil
+ PGresult *PQexecPrepared(PGconn *conn,
+ const char *stmtName,
+ int nParams,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ PGresult *PQdescribePrepared(PGconn *conn, const char *stmtName) nogil
+ PGresult *PQdescribePortal(PGconn *conn, const char *portalName) nogil
+ ExecStatusType PQresultStatus(const PGresult *res) nogil
+ # PQresStatus: not needed, we have pretty enums
+ char *PQresultErrorMessage(const PGresult *res) nogil
+ # TODO: PQresultVerboseErrorMessage
+ char *PQresultErrorField(const PGresult *res, int fieldcode) nogil
+ void PQclear(PGresult *res) nogil
+
+ # 33.3.2. Retrieving Query Result Information
+ int PQntuples(const PGresult *res)
+ int PQnfields(const PGresult *res)
+ char *PQfname(const PGresult *res, int column_number)
+ int PQfnumber(const PGresult *res, const char *column_name)
+ Oid PQftable(const PGresult *res, int column_number)
+ int PQftablecol(const PGresult *res, int column_number)
+ int PQfformat(const PGresult *res, int column_number)
+ Oid PQftype(const PGresult *res, int column_number)
+ int PQfmod(const PGresult *res, int column_number)
+ int PQfsize(const PGresult *res, int column_number)
+ int PQbinaryTuples(const PGresult *res)
+ char *PQgetvalue(const PGresult *res, int row_number, int column_number)
+ int PQgetisnull(const PGresult *res, int row_number, int column_number)
+ int PQgetlength(const PGresult *res, int row_number, int column_number)
+ int PQnparams(const PGresult *res)
+ Oid PQparamtype(const PGresult *res, int param_number)
+ # PQprint: pretty useless
+
+ # 33.3.3. Retrieving Other Result Information
+ char *PQcmdStatus(PGresult *res)
+ char *PQcmdTuples(PGresult *res)
+ Oid PQoidValue(const PGresult *res)
+
+ # 33.3.4. Escaping Strings for Inclusion in SQL Commands
+ char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length)
+ char *PQescapeLiteral(PGconn *conn, const char *str, size_t length)
+ size_t PQescapeStringConn(PGconn *conn,
+ char *to, const char *from_, size_t length,
+ int *error)
+ size_t PQescapeString(char *to, const char *from_, size_t length)
+ unsigned char *PQescapeByteaConn(PGconn *conn,
+ const unsigned char *src,
+ size_t from_length,
+ size_t *to_length)
+ unsigned char *PQescapeBytea(const unsigned char *src,
+ size_t from_length,
+ size_t *to_length)
+ unsigned char *PQunescapeBytea(const unsigned char *src, size_t *to_length)
+
+
+ # 33.4. Asynchronous Command Processing
+ int PQsendQuery(PGconn *conn, const char *command) nogil
+ int PQsendQueryParams(PGconn *conn,
+ const char *command,
+ int nParams,
+ const Oid *paramTypes,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ int PQsendPrepare(PGconn *conn,
+ const char *stmtName,
+ const char *query,
+ int nParams,
+ const Oid *paramTypes) nogil
+ int PQsendQueryPrepared(PGconn *conn,
+ const char *stmtName,
+ int nParams,
+ const char * const *paramValues,
+ const int *paramLengths,
+ const int *paramFormats,
+ int resultFormat) nogil
+ int PQsendDescribePrepared(PGconn *conn, const char *stmtName) nogil
+ int PQsendDescribePortal(PGconn *conn, const char *portalName) nogil
+ PGresult *PQgetResult(PGconn *conn) nogil
+ int PQconsumeInput(PGconn *conn) nogil
+ int PQisBusy(PGconn *conn) nogil
+ int PQsetnonblocking(PGconn *conn, int arg) nogil
+ int PQisnonblocking(const PGconn *conn)
+ int PQflush(PGconn *conn) nogil
+
+ # 33.5. Retrieving Query Results Row-by-Row
+ int PQsetSingleRowMode(PGconn *conn)
+
+ # 33.6. Canceling Queries in Progress
+ PGcancel *PQgetCancel(PGconn *conn)
+ void PQfreeCancel(PGcancel *cancel)
+ int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize)
+
+ # 33.8. Asynchronous Notification
+ PGnotify *PQnotifies(PGconn *conn) nogil
+
+ # 33.9. Functions Associated with the COPY Command
+ int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) nogil
+ int PQputCopyEnd(PGconn *conn, const char *errormsg) nogil
+ int PQgetCopyData(PGconn *conn, char **buffer, int async) nogil
+
+ # 33.10. Control Functions
+ void PQtrace(PGconn *conn, FILE *stream);
+ void PQsetTraceFlags(PGconn *conn, int flags);
+ void PQuntrace(PGconn *conn);
+
+ # 33.11. Miscellaneous Functions
+ void PQfreemem(void *ptr) nogil
+ void PQconninfoFree(PQconninfoOption *connOptions)
+ char *PQencryptPasswordConn(
+ PGconn *conn, const char *passwd, const char *user, const char *algorithm);
+ PGresult *PQmakeEmptyPGresult(PGconn *conn, ExecStatusType status)
+ int PQsetResultAttrs(PGresult *res, int numAttributes, PGresAttDesc *attDescs)
+ int PQlibVersion()
+
+ # 33.12. Notice Processing
+ ctypedef void (*PQnoticeReceiver)(void *arg, const PGresult *res)
+ PQnoticeReceiver PQsetNoticeReceiver(
+ PGconn *conn, PQnoticeReceiver prog, void *arg)
+
+ # 33.18. SSL Support
+ void PQinitOpenSSL(int do_ssl, int do_crypto)
+
+ # 34.5 Pipeline Mode
+
+ ctypedef enum PGpipelineStatus:
+ PQ_PIPELINE_OFF
+ PQ_PIPELINE_ON
+ PQ_PIPELINE_ABORTED
+
+ PGpipelineStatus PQpipelineStatus(const PGconn *conn)
+ int PQenterPipelineMode(PGconn *conn)
+ int PQexitPipelineMode(PGconn *conn)
+ int PQpipelineSync(PGconn *conn)
+ int PQsendFlushRequest(PGconn *conn)
+
+cdef extern from *:
+ """
+/* Hack to allow the use of old libpq versions */
+#if PG_VERSION_NUM < 100000
+#define PQencryptPasswordConn(conn, passwd, user, algorithm) NULL
+#endif
+
+#if PG_VERSION_NUM < 120000
+#define PQhostaddr(conn) NULL
+#endif
+
+#if PG_VERSION_NUM < 140000
+#define PGRES_PIPELINE_SYNC 10
+#define PGRES_PIPELINE_ABORTED 11
+typedef enum {
+ PQ_PIPELINE_OFF,
+ PQ_PIPELINE_ON,
+ PQ_PIPELINE_ABORTED
+} PGpipelineStatus;
+#define PQpipelineStatus(conn) PQ_PIPELINE_OFF
+#define PQenterPipelineMode(conn) 0
+#define PQexitPipelineMode(conn) 1
+#define PQpipelineSync(conn) 0
+#define PQsendFlushRequest(conn) 0
+#define PQsetTraceFlags(conn, stream) do {} while (0)
+#endif
+"""
diff --git a/psycopg_c/psycopg_c/pq/pgcancel.pyx b/psycopg_c/psycopg_c/pq/pgcancel.pyx
new file mode 100644
index 0000000..b7cbb70
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgcancel.pyx
@@ -0,0 +1,32 @@
+"""
+psycopg_c.pq.PGcancel object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+cdef class PGcancel:
+ def __cinit__(self):
+ self.pgcancel_ptr = NULL
+
+ @staticmethod
+ cdef PGcancel _from_ptr(libpq.PGcancel *ptr):
+ cdef PGcancel rv = PGcancel.__new__(PGcancel)
+ rv.pgcancel_ptr = ptr
+ return rv
+
+ def __dealloc__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ if self.pgcancel_ptr is not NULL:
+ libpq.PQfreeCancel(self.pgcancel_ptr)
+ self.pgcancel_ptr = NULL
+
+ def cancel(self) -> None:
+ cdef char buf[256]
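+        # On failure PQcancel() writes an error message into buf.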
+ cdef int res = libpq.PQcancel(self.pgcancel_ptr, buf, sizeof(buf))
+ if not res:
+ raise e.OperationalError(
+ f"cancel failed: {buf.decode('utf8', 'ignore')}"
+ )
diff --git a/psycopg_c/psycopg_c/pq/pgconn.pyx b/psycopg_c/psycopg_c/pq/pgconn.pyx
new file mode 100644
index 0000000..4a60530
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgconn.pyx
@@ -0,0 +1,733 @@
+"""
+psycopg_c.pq.PGconn object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cdef extern from * nogil:
+ """
+#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS)
+ /* We don't need a real definition for this because Windows is not affected
+ * by the issue caused by closing the fds after fork.
+ */
+ #define getpid() (0)
+#else
+ #include <unistd.h>
+#endif
+ """
+ pid_t getpid()
+
+from libc.stdio cimport fdopen
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from cpython.bytes cimport PyBytes_AsString
+from cpython.memoryview cimport PyMemoryView_FromObject
+
+import sys
+
+from psycopg.pq import Format as PqFormat, Trace
+from psycopg.pq.misc import PGnotify, connection_summary
+from psycopg_c.pq cimport PQBuffer
+
+
+cdef class PGconn:
+ @staticmethod
+ cdef PGconn _from_ptr(libpq.PGconn *ptr):
+ cdef PGconn rv = PGconn.__new__(PGconn)
+ rv._pgconn_ptr = ptr
+
+ libpq.PQsetNoticeReceiver(ptr, notice_receiver, <void *>rv)
+ return rv
+
+ def __cinit__(self):
+ self._pgconn_ptr = NULL
+ self._procpid = getpid()
+
+ def __dealloc__(self):
+ # Close the connection only if it was created in this process,
+ # not if this object is being GC'd after fork.
+ if self._procpid == getpid():
+ self.finish()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ info = connection_summary(self)
+ return f"<{cls} {info} at 0x{id(self):x}>"
+
+ @classmethod
+ def connect(cls, const char *conninfo) -> PGconn:
+ cdef libpq.PGconn* pgconn = libpq.PQconnectdb(conninfo)
+ if not pgconn:
+ raise MemoryError("couldn't allocate PGconn")
+
+ return PGconn._from_ptr(pgconn)
+
+ @classmethod
+ def connect_start(cls, const char *conninfo) -> PGconn:
+ cdef libpq.PGconn* pgconn = libpq.PQconnectStart(conninfo)
+ if not pgconn:
+ raise MemoryError("couldn't allocate PGconn")
+
+ return PGconn._from_ptr(pgconn)
+
+ def connect_poll(self) -> int:
+ return _call_int(self, <conn_int_f>libpq.PQconnectPoll)
+
+ def finish(self) -> None:
+ if self._pgconn_ptr is not NULL:
+ libpq.PQfinish(self._pgconn_ptr)
+ self._pgconn_ptr = NULL
+
+ @property
+ def pgconn_ptr(self) -> Optional[int]:
+ if self._pgconn_ptr:
+ return <long long><void *>self._pgconn_ptr
+ else:
+ return None
+
+ @property
+ def info(self) -> List["ConninfoOption"]:
+ _ensure_pgconn(self)
+ cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr)
+ if opts is NULL:
+ raise MemoryError("couldn't allocate connection info")
+ rv = _options_from_array(opts)
+ libpq.PQconninfoFree(opts)
+ return rv
+
+ def reset(self) -> None:
+ _ensure_pgconn(self)
+ libpq.PQreset(self._pgconn_ptr)
+
+ def reset_start(self) -> None:
+ if not libpq.PQresetStart(self._pgconn_ptr):
+ raise e.OperationalError("couldn't reset connection")
+
+ def reset_poll(self) -> int:
+ return _call_int(self, <conn_int_f>libpq.PQresetPoll)
+
+ @classmethod
+    def ping(cls, const char *conninfo) -> int:
+ return libpq.PQping(conninfo)
+
+ @property
+ def db(self) -> bytes:
+ return _call_bytes(self, libpq.PQdb)
+
+ @property
+ def user(self) -> bytes:
+ return _call_bytes(self, libpq.PQuser)
+
+ @property
+ def password(self) -> bytes:
+ return _call_bytes(self, libpq.PQpass)
+
+ @property
+ def host(self) -> bytes:
+ return _call_bytes(self, libpq.PQhost)
+
+ @property
+ def hostaddr(self) -> bytes:
+ if libpq.PG_VERSION_NUM < 120000:
+ raise e.NotSupportedError(
+ f"PQhostaddr requires libpq from PostgreSQL 12,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+
+ _ensure_pgconn(self)
+ cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr)
+ assert rv is not NULL
+ return rv
+
+ @property
+ def port(self) -> bytes:
+ return _call_bytes(self, libpq.PQport)
+
+ @property
+ def tty(self) -> bytes:
+ return _call_bytes(self, libpq.PQtty)
+
+ @property
+ def options(self) -> bytes:
+ return _call_bytes(self, libpq.PQoptions)
+
+ @property
+ def status(self) -> int:
+ return libpq.PQstatus(self._pgconn_ptr)
+
+ @property
+ def transaction_status(self) -> int:
+ return libpq.PQtransactionStatus(self._pgconn_ptr)
+
+ def parameter_status(self, const char *name) -> Optional[bytes]:
+ _ensure_pgconn(self)
+ cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def error_message(self) -> bytes:
+ return libpq.PQerrorMessage(self._pgconn_ptr)
+
+ @property
+ def protocol_version(self) -> int:
+ return _call_int(self, libpq.PQprotocolVersion)
+
+ @property
+ def server_version(self) -> int:
+ return _call_int(self, libpq.PQserverVersion)
+
+ @property
+ def socket(self) -> int:
+ rv = _call_int(self, libpq.PQsocket)
+ if rv == -1:
+ raise e.OperationalError("the connection is lost")
+ return rv
+
+ @property
+ def backend_pid(self) -> int:
+ return _call_int(self, libpq.PQbackendPID)
+
+ @property
+ def needs_password(self) -> bool:
+ return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr))
+
+ @property
+ def used_password(self) -> bool:
+ return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr))
+
+ @property
+ def ssl_in_use(self) -> bool:
+ return bool(_call_int(self, <conn_int_f>libpq.PQsslInUse))
+
+ def exec_(self, const char *command) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *pgresult
+ with nogil:
+ pgresult = libpq.PQexec(self._pgconn_ptr, command)
+ if pgresult is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+
+ return PGresult._from_ptr(pgresult)
+
+ def send_query(self, const char *command) -> None:
+ _ensure_pgconn(self)
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQuery(self._pgconn_ptr, command)
+ if not rv:
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
+
+ def exec_params(
+ self,
+ const char *command,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, param_types, param_formats)
+
+ cdef libpq.PGresult *pgresult
+ with nogil:
+ pgresult = libpq.PQexecParams(
+ self._pgconn_ptr, command, <int>cnparams, ctypes,
+ <const char *const *>cvalues, clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if pgresult is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(pgresult)
+
+ def send_query_params(
+ self,
+ const char *command,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]] = None,
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, param_types, param_formats)
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQueryParams(
+ self._pgconn_ptr, command, <int>cnparams, ctypes,
+ <const char *const *>cvalues, clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if not rv:
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_prepare(
+ self,
+ const char *name,
+ const char *command,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef int i
+ cdef Py_ssize_t nparams = len(param_types) if param_types else 0
+ cdef libpq.Oid *atypes = NULL
+ if nparams:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = param_types[i]
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendPrepare(
+ self._pgconn_ptr, name, command, <int>nparams, atypes
+ )
+ PyMem_Free(atypes)
+ if not rv:
+ raise e.OperationalError(
+ f"sending query and params failed: {error_message(self)}"
+ )
+
+ def send_query_prepared(
+ self,
+ const char *name,
+ param_values: Optional[Sequence[Optional[bytes]]],
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> None:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, None, param_formats)
+
+ cdef int rv
+ with nogil:
+ rv = libpq.PQsendQueryPrepared(
+ self._pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues,
+ clengths, cformats, result_format)
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if not rv:
+ raise e.OperationalError(
+ f"sending prepared query failed: {error_message(self)}"
+ )
+
+ def prepare(
+ self,
+ const char *name,
+ const char *command,
+ param_types: Optional[Sequence[int]] = None,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef int i
+ cdef Py_ssize_t nparams = len(param_types) if param_types else 0
+ cdef libpq.Oid *atypes = NULL
+ if nparams:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = param_types[i]
+
+ cdef libpq.PGresult *rv
+ with nogil:
+ rv = libpq.PQprepare(
+ self._pgconn_ptr, name, command, <int>nparams, atypes)
+ PyMem_Free(atypes)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def exec_prepared(
+ self,
+ const char *name,
+ param_values: Optional[Sequence[bytes]],
+ param_formats: Optional[Sequence[int]] = None,
+ int result_format = PqFormat.TEXT,
+ ) -> PGresult:
+ _ensure_pgconn(self)
+
+ cdef Py_ssize_t cnparams
+ cdef libpq.Oid *ctypes
+ cdef char *const *cvalues
+ cdef int *clengths
+ cdef int *cformats
+ cnparams, ctypes, cvalues, clengths, cformats = _query_params_args(
+ param_values, None, param_formats)
+
+ cdef libpq.PGresult *rv
+ with nogil:
+ rv = libpq.PQexecPrepared(
+ self._pgconn_ptr, name, <int>cnparams,
+ <const char *const *>cvalues,
+ clengths, cformats, result_format)
+
+ _clear_query_params(ctypes, cvalues, clengths, cformats)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def describe_prepared(self, const char *name) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def send_describe_prepared(self, const char *name) -> None:
+ _ensure_pgconn(self)
+ cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name)
+ if not rv:
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def describe_portal(self, const char *name) -> PGresult:
+ _ensure_pgconn(self)
+ cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name)
+ if rv is NULL:
+ raise MemoryError("couldn't allocate PGresult")
+ return PGresult._from_ptr(rv)
+
+ def send_describe_portal(self, const char *name) -> None:
+ _ensure_pgconn(self)
+ cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name)
+ if not rv:
+ raise e.OperationalError(
+ f"sending describe prepared failed: {error_message(self)}"
+ )
+
+ def get_result(self) -> Optional["PGresult"]:
+ cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr)
+ if pgresult is NULL:
+ return None
+ return PGresult._from_ptr(pgresult)
+
+ def consume_input(self) -> None:
+ if 1 != libpq.PQconsumeInput(self._pgconn_ptr):
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
+
+ def is_busy(self) -> int:
+ cdef int rv
+ with nogil:
+ rv = libpq.PQisBusy(self._pgconn_ptr)
+ return rv
+
+ @property
+ def nonblocking(self) -> int:
+ return libpq.PQisnonblocking(self._pgconn_ptr)
+
+ @nonblocking.setter
+ def nonblocking(self, int arg) -> None:
+ if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg):
+ raise e.OperationalError(f"setting nonblocking failed: {error_message(self)}")
+
+ cpdef int flush(self) except -1:
+        if self._pgconn_ptr is NULL:
+            raise e.OperationalError("flushing failed: the connection is closed")
+ cdef int rv = libpq.PQflush(self._pgconn_ptr)
+ if rv < 0:
+ raise e.OperationalError(f"flushing failed: {error_message(self)}")
+ return rv
+
+ def set_single_row_mode(self) -> None:
+ if not libpq.PQsetSingleRowMode(self._pgconn_ptr):
+ raise e.OperationalError("setting single row mode failed")
+
+ def get_cancel(self) -> PGcancel:
+ cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr)
+ if not ptr:
+ raise e.OperationalError("couldn't create cancel object")
+ return PGcancel._from_ptr(ptr)
+
+ cpdef object notifies(self):
+ cdef libpq.PGnotify *ptr
+ with nogil:
+ ptr = libpq.PQnotifies(self._pgconn_ptr)
+ if ptr:
+ ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra)
+ libpq.PQfreemem(ptr)
+ return ret
+ else:
+ return None
+
+ def put_copy_data(self, buffer) -> int:
+ cdef int rv
+ cdef char *cbuffer
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(buffer, &cbuffer, &length)
+ rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
+ return rv
+
+ def put_copy_end(self, error: Optional[bytes] = None) -> int:
+ cdef int rv
+ cdef const char *cerr = NULL
+ if error is not None:
+ cerr = PyBytes_AsString(error)
+ rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr)
+ if rv < 0:
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
+ return rv
+
+ def get_copy_data(self, int async_) -> Tuple[int, memoryview]:
+ cdef char *buffer_ptr = NULL
+ cdef int nbytes
+ nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_)
+ if nbytes == -2:
+ raise e.OperationalError(f"receiving copy data failed: {error_message(self)}")
+ if buffer_ptr is not NULL:
+ data = PyMemoryView_FromObject(
+ PQBuffer._from_buffer(<unsigned char *>buffer_ptr, nbytes))
+ return nbytes, data
+ else:
+            return nbytes, b""  # nothing to parse: no need for a memoryview here
+
+ def trace(self, fileno: int) -> None:
+ if sys.platform != "linux":
+ raise e.NotSupportedError("currently only supported on Linux")
+ stream = fdopen(fileno, b"w")
+ libpq.PQtrace(self._pgconn_ptr, stream)
+
+ def set_trace_flags(self, flags: Trace) -> None:
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQsetTraceFlags requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ libpq.PQsetTraceFlags(self._pgconn_ptr, flags)
+
+ def untrace(self) -> None:
+ libpq.PQuntrace(self._pgconn_ptr)
+
+ def encrypt_password(
+ self, const char *passwd, const char *user, algorithm = None
+ ) -> bytes:
+ if libpq.PG_VERSION_NUM < 100000:
+ raise e.NotSupportedError(
+ f"PQencryptPasswordConn requires libpq from PostgreSQL 10,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+
+ cdef char *out
+ cdef const char *calgo = NULL
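+        # "algorithm" is e.g. b"md5" or b"scram-sha-256"; if None, libpq asks
+        # the server for its password_encryption setting.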
+ if algorithm:
+ calgo = algorithm
+ out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo)
+ if not out:
+ raise e.OperationalError(
+ f"password encryption failed: {error_message(self)}"
+ )
+
+ rv = bytes(out)
+ libpq.PQfreemem(out)
+ return rv
+
+ def make_empty_result(self, int exec_status) -> PGresult:
+ cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult(
+ self._pgconn_ptr, <libpq.ExecStatusType>exec_status)
+ if not rv:
+ raise MemoryError("couldn't allocate empty PGresult")
+ return PGresult._from_ptr(rv)
+
+ @property
+ def pipeline_status(self) -> int:
+ """The current pipeline mode status.
+
+ For libpq < 14.0, always return 0 (PQ_PIPELINE_OFF).
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ return libpq.PQ_PIPELINE_OFF
+ cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr)
+ return status
+
+ def enter_pipeline_mode(self) -> None:
+ """Enter pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to enter the pipeline
+ mode.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQenterPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError("failed to enter pipeline mode")
+
+ def exit_pipeline_mode(self) -> None:
+ """Exit pipeline mode.
+
+ :raises ~e.OperationalError: in case of failure to exit the pipeline
+ mode.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQexitPipelineMode requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1:
+ raise e.OperationalError(error_message(self))
+
+ def pipeline_sync(self) -> None:
+ """Mark a synchronization point in a pipeline.
+
+ :raises ~e.OperationalError: if the connection is not in pipeline mode
+ or if sync failed.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQpipelineSync requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ rv = libpq.PQpipelineSync(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError("connection not in pipeline mode")
+ if rv != 1:
+ raise e.OperationalError("failed to sync pipeline")
+
+ def send_flush_request(self) -> None:
+ """Sends a request for the server to flush its output buffer.
+
+ :raises ~e.OperationalError: if the flush request failed.
+ """
+ if libpq.PG_VERSION_NUM < 140000:
+ raise e.NotSupportedError(
+ f"PQsendFlushRequest requires libpq from PostgreSQL 14,"
+ f" {libpq.PG_VERSION_NUM} available instead"
+ )
+ cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr)
+ if rv == 0:
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
+
+
+cdef int _ensure_pgconn(PGconn pgconn) except 0:
+ if pgconn._pgconn_ptr is not NULL:
+ return 1
+
+ raise e.OperationalError("the connection is closed")
+
+
+cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL:
+ """
+ Call one of the pgconn libpq functions returning a bytes pointer.
+ """
+ if not _ensure_pgconn(pgconn):
+ return NULL
+ cdef char *rv = func(pgconn._pgconn_ptr)
+ assert rv is not NULL
+ return rv
+
+
+cdef int _call_int(PGconn pgconn, conn_int_f func) except -2:
+ """
+ Call one of the pgconn libpq functions returning an int.
+ """
+ if not _ensure_pgconn(pgconn):
+ return -2
+ return func(pgconn._pgconn_ptr)
+
+
+cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil:
+ cdef PGconn pgconn = <object>arg
+ if pgconn.notice_handler is None:
+ return
+
+ cdef PGresult res = PGresult._from_ptr(<libpq.PGresult *>res_ptr)
+ try:
+ pgconn.notice_handler(res)
+ except Exception as e:
+ logger.exception("error in notice receiver: %s", e)
+ finally:
+ res._pgresult_ptr = NULL # avoid destroying the pgresult_ptr
+
+
+cdef (Py_ssize_t, libpq.Oid *, char * const*, int *, int *) _query_params_args(
+ list param_values: Optional[Sequence[Optional[bytes]]],
+ param_types: Optional[Sequence[int]],
+ list param_formats: Optional[Sequence[int]],
+) except *:
+ cdef int i
+
+    # PostgresQuery already converts param_types to a tuple, so this
+    # conversion is most often a no-op
+ cdef tuple tparam_types
+ if param_types is not None and not isinstance(param_types, tuple):
+ tparam_types = tuple(param_types)
+ else:
+ tparam_types = param_types
+
+ cdef Py_ssize_t nparams = len(param_values) if param_values else 0
+ if tparam_types is not None and len(tparam_types) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_types"
+ % (nparams, len(tparam_types))
+ )
+ if param_formats is not None and len(param_formats) != nparams:
+ raise ValueError(
+ "got %d param_values but %d param_formats"
+ % (nparams, len(param_formats))
+ )
+
+ cdef char **aparams = NULL
+    cdef int *alengths = NULL
+ cdef char *ptr
+ cdef Py_ssize_t length
+
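+    # Build C arrays of pointers into the Python buffers: no data is copied,
+    # so the Python objects must stay alive until the query has been sent.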
+ if nparams:
+ aparams = <char **>PyMem_Malloc(nparams * sizeof(char *))
+        alengths = <int *>PyMem_Malloc(nparams * sizeof(int))
+ for i in range(nparams):
+ obj = param_values[i]
+ if obj is None:
+ aparams[i] = NULL
+                alengths[i] = 0
+ else:
+ # TODO: it is a leak if this fails (but it should only fail
+ # on internal error, e.g. if obj is not a buffer)
+ _buffer_as_string_and_size(obj, &ptr, &length)
+ aparams[i] = ptr
+                alengths[i] = <int>length
+
+ cdef libpq.Oid *atypes = NULL
+ if tparam_types:
+ atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid))
+ for i in range(nparams):
+ atypes[i] = tparam_types[i]
+
+ cdef int *aformats = NULL
+ if param_formats is not None:
+        aformats = <int *>PyMem_Malloc(nparams * sizeof(int))
+ for i in range(nparams):
+ aformats[i] = param_formats[i]
+
+    return (nparams, atypes, aparams, alengths, aformats)
+
+
+cdef void _clear_query_params(
+    libpq.Oid *ctypes, char *const *cvalues, int *clengths, int *cformats
+):
+ PyMem_Free(ctypes)
+ PyMem_Free(<char **>cvalues)
+    PyMem_Free(clengths)
+ PyMem_Free(cformats)
diff --git a/psycopg_c/psycopg_c/pq/pgresult.pyx b/psycopg_c/psycopg_c/pq/pgresult.pyx
new file mode 100644
index 0000000..6df42e8
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pgresult.pyx
@@ -0,0 +1,157 @@
+"""
+psycopg_c.pq.PGresult object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+
+from psycopg.pq.misc import PGresAttDesc
+from psycopg.pq._enums import ExecStatus
+
+
+@cython.freelist(8)
+cdef class PGresult:
+ def __cinit__(self):
+ self._pgresult_ptr = NULL
+
+ @staticmethod
+ cdef PGresult _from_ptr(libpq.PGresult *ptr):
+ cdef PGresult rv = PGresult.__new__(PGresult)
+ rv._pgresult_ptr = ptr
+ return rv
+
+ def __dealloc__(self) -> None:
+ self.clear()
+
+ def __repr__(self) -> str:
+ cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ status = ExecStatus(self.status)
+ return f"<{cls} [{status.name}] at 0x{id(self):x}>"
+
+ def clear(self) -> None:
+ if self._pgresult_ptr is not NULL:
+ libpq.PQclear(self._pgresult_ptr)
+ self._pgresult_ptr = NULL
+
+ @property
+ def pgresult_ptr(self) -> Optional[int]:
+ if self._pgresult_ptr:
+ return <long long><void *>self._pgresult_ptr
+ else:
+ return None
+
+ @property
+ def status(self) -> int:
+ return libpq.PQresultStatus(self._pgresult_ptr)
+
+ @property
+ def error_message(self) -> bytes:
+ return libpq.PQresultErrorMessage(self._pgresult_ptr)
+
+ def error_field(self, int fieldcode) -> Optional[bytes]:
+ cdef char * rv = libpq.PQresultErrorField(self._pgresult_ptr, fieldcode)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def ntuples(self) -> int:
+ return libpq.PQntuples(self._pgresult_ptr)
+
+ @property
+ def nfields(self) -> int:
+ return libpq.PQnfields(self._pgresult_ptr)
+
+ def fname(self, int column_number) -> Optional[bytes]:
+ cdef char *rv = libpq.PQfname(self._pgresult_ptr, column_number)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ def ftable(self, int column_number) -> int:
+ return libpq.PQftable(self._pgresult_ptr, column_number)
+
+ def ftablecol(self, int column_number) -> int:
+ return libpq.PQftablecol(self._pgresult_ptr, column_number)
+
+ def fformat(self, int column_number) -> int:
+ return libpq.PQfformat(self._pgresult_ptr, column_number)
+
+ def ftype(self, int column_number) -> int:
+ return libpq.PQftype(self._pgresult_ptr, column_number)
+
+ def fmod(self, int column_number) -> int:
+ return libpq.PQfmod(self._pgresult_ptr, column_number)
+
+ def fsize(self, int column_number) -> int:
+ return libpq.PQfsize(self._pgresult_ptr, column_number)
+
+ @property
+ def binary_tuples(self) -> int:
+ return libpq.PQbinaryTuples(self._pgresult_ptr)
+
+ def get_value(self, int row_number, int column_number) -> Optional[bytes]:
+ cdef int crow = row_number
+ cdef int ccol = column_number
+ cdef int length = libpq.PQgetlength(self._pgresult_ptr, crow, ccol)
+ cdef char *v
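+        # A zero length may mean either SQL NULL or an empty string:
+        # PQgetisnull() tells the two cases apart.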
+ if length:
+ v = libpq.PQgetvalue(self._pgresult_ptr, crow, ccol)
+ # TODO: avoid copy
+ return v[:length]
+ else:
+ if libpq.PQgetisnull(self._pgresult_ptr, crow, ccol):
+ return None
+ else:
+ return b""
+
+ @property
+ def nparams(self) -> int:
+ return libpq.PQnparams(self._pgresult_ptr)
+
+ def param_type(self, int param_number) -> int:
+ return libpq.PQparamtype(self._pgresult_ptr, param_number)
+
+ @property
+ def command_status(self) -> Optional[bytes]:
+ cdef char *rv = libpq.PQcmdStatus(self._pgresult_ptr)
+ if rv is not NULL:
+ return rv
+ else:
+ return None
+
+ @property
+ def command_tuples(self) -> Optional[int]:
+ cdef char *rv = libpq.PQcmdTuples(self._pgresult_ptr)
+ if rv is NULL:
+ return None
+ cdef bytes brv = rv
+ return int(brv) if brv else None
+
+ @property
+ def oid_value(self) -> int:
+ return libpq.PQoidValue(self._pgresult_ptr)
+
+ def set_attributes(self, descriptions: List[PGresAttDesc]):
+ cdef Py_ssize_t num = len(descriptions)
+ cdef libpq.PGresAttDesc *attrs = <libpq.PGresAttDesc *>PyMem_Malloc(
+ num * sizeof(libpq.PGresAttDesc))
+
+ for i in range(num):
+ descr = descriptions[i]
+ attrs[i].name = descr.name
+ attrs[i].tableid = descr.tableid
+ attrs[i].columnid = descr.columnid
+ attrs[i].format = descr.format
+ attrs[i].typid = descr.typid
+ attrs[i].typlen = descr.typlen
+ attrs[i].atttypmod = descr.atttypmod
+
+ cdef int res = libpq.PQsetResultAttrs(self._pgresult_ptr, <int>num, attrs)
+ PyMem_Free(attrs)
+        if res == 0:
+ raise e.OperationalError("PQsetResultAttrs failed")
diff --git a/psycopg_c/psycopg_c/pq/pqbuffer.pyx b/psycopg_c/psycopg_c/pq/pqbuffer.pyx
new file mode 100644
index 0000000..eb5d648
--- /dev/null
+++ b/psycopg_c/psycopg_c/pq/pqbuffer.pyx
@@ -0,0 +1,111 @@
+"""
+PQbuffer object implementation.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+from cpython.bytes cimport PyBytes_AsStringAndSize
+from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE
+from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release
+
+
+@cython.freelist(32)
+cdef class PQBuffer:
+ """
+ Wrap a chunk of memory allocated by the libpq and expose it as memoryview.
+ """
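+    # The object takes ownership of the memory: it is released with
+    # PQfreemem() when the buffer is deallocated.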
+ @staticmethod
+ cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length):
+ cdef PQBuffer rv = PQBuffer.__new__(PQBuffer)
+ rv.buf = buf
+ rv.len = length
+ return rv
+
+ def __cinit__(self):
+ self.buf = NULL
+ self.len = 0
+
+ def __dealloc__(self):
+ if self.buf:
+ libpq.PQfreemem(self.buf)
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ f"({bytes(self)})"
+ )
+
+ def __getbuffer__(self, Py_buffer *buffer, int flags):
+ buffer.buf = self.buf
+ buffer.obj = self
+ buffer.len = self.len
+ buffer.itemsize = sizeof(unsigned char)
+ buffer.readonly = 1
+ buffer.ndim = 1
+ buffer.format = NULL # unsigned char
+ buffer.shape = &self.len
+ buffer.strides = NULL
+ buffer.suboffsets = NULL
+ buffer.internal = NULL
+
+ def __releasebuffer__(self, Py_buffer *buffer):
+ pass
+
+
+@cython.freelist(32)
+cdef class ViewBuffer:
+ """
+ Wrap a chunk of memory owned by a different object.
+ """
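+    # A reference to the owner ("obj") is kept so that the underlying memory
+    # stays alive for as long as the view does.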
+ @staticmethod
+ cdef ViewBuffer _from_buffer(
+ object obj, unsigned char *buf, Py_ssize_t length
+ ):
+ cdef ViewBuffer rv = ViewBuffer.__new__(ViewBuffer)
+ rv.obj = obj
+ rv.buf = buf
+ rv.len = length
+ return rv
+
+ def __cinit__(self):
+ self.buf = NULL
+ self.len = 0
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+ f"({bytes(self)})"
+ )
+
+ def __getbuffer__(self, Py_buffer *buffer, int flags):
+ buffer.buf = self.buf
+ buffer.obj = self
+ buffer.len = self.len
+ buffer.itemsize = sizeof(unsigned char)
+ buffer.readonly = 1
+ buffer.ndim = 1
+ buffer.format = NULL # unsigned char
+ buffer.shape = &self.len
+ buffer.strides = NULL
+ buffer.suboffsets = NULL
+ buffer.internal = NULL
+
+ def __releasebuffer__(self, Py_buffer *buffer):
+ pass
+
+
+cdef int _buffer_as_string_and_size(
+ data: "Buffer", char **ptr, Py_ssize_t *length
+) except -1:
+ cdef Py_buffer buf
+
+ if isinstance(data, bytes):
+ PyBytes_AsStringAndSize(data, ptr, length)
+ elif PyObject_CheckBuffer(data):
+ PyObject_GetBuffer(data, &buf, PyBUF_SIMPLE)
+ ptr[0] = <char *>buf.buf
+ length[0] = buf.len
+ PyBuffer_Release(&buf)
+ else:
+ raise TypeError(f"bytes or buffer expected, got {type(data)}")
diff --git a/psycopg_c/psycopg_c/py.typed b/psycopg_c/psycopg_c/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg_c/psycopg_c/py.typed
diff --git a/psycopg_c/psycopg_c/types/array.pyx b/psycopg_c/psycopg_c/types/array.pyx
new file mode 100644
index 0000000..9abaef9
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/array.pyx
@@ -0,0 +1,276 @@
+"""
+C optimized functions to manipulate arrays
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import cython
+
+from libc.stdint cimport int32_t, uint32_t
+from libc.string cimport memset, strchr
+from cpython.mem cimport PyMem_Realloc, PyMem_Free
+from cpython.ref cimport Py_INCREF
+from cpython.list cimport PyList_New,PyList_Append, PyList_GetSlice
+from cpython.list cimport PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE
+from cpython.object cimport PyObject
+
+from psycopg_c.pq cimport _buffer_as_string_and_size
+from psycopg_c.pq.libpq cimport Oid
+from psycopg_c._psycopg cimport endian
+
+from psycopg import errors as e
+
+cdef extern from *:
+ """
+/* Defined in PostgreSQL in src/include/utils/array.h */
+#define MAXDIM 6
+ """
+ const int MAXDIM
+
+
+cdef class ArrayLoader(_CRecursiveLoader):
+
+ format = PQ_TEXT
+ base_oid = 0
+ delimiter = b","
+
+ cdef PyObject *row_loader
+ cdef char cdelim
+
+    # A scratch memory area used to unescape elements.
+    # Keeping it on the loader avoids a malloc per element and makes sure
+    # it is freed (in __dealloc__) even on error.
+ cdef char *scratch
+ cdef size_t sclen
+
+ cdef object cload(self, const char *data, size_t length):
+ if self.cdelim == b"\x00":
+ self.row_loader = self._tx._c_get_loader(
+ <PyObject *>self.base_oid, <PyObject *>PQ_TEXT)
+ self.cdelim = self.delimiter[0]
+
+ return _array_load_text(
+ data, length, self.row_loader, self.cdelim,
+ &(self.scratch), &(self.sclen))
+
+ def __dealloc__(self):
+ PyMem_Free(self.scratch)
+
+
+@cython.final
+cdef class ArrayBinaryLoader(_CRecursiveLoader):
+
+ format = PQ_BINARY
+
+ cdef PyObject *row_loader
+
+ cdef object cload(self, const char *data, size_t length):
+ rv = _array_load_binary(data, length, self._tx, &(self.row_loader))
+ return rv
+
+
+cdef object _array_load_text(
+ const char *buf, size_t length, PyObject *row_loader, char cdelim,
+ char **scratch, size_t *sclen
+):
+ if length == 0:
+ raise e.DataError("malformed array: empty data")
+
+ cdef const char *end = buf + length
+
+ # Remove the dimensions information prefix (``[...]=``)
+ if buf[0] == b"[":
+ buf = strchr(buf + 1, b'=')
+ if buf == NULL:
+ raise e.DataError("malformed array: no '=' after dimension information")
+ buf += 1
+
+    # TODO: further optimization: pre-scan the array to find its dimensions,
+    # so that lists can be preallocated to the right size instead of calling
+    # append, which is the dominating operation
+
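+    # "stack" holds the lists currently open (a '{' seen, its '}' not yet);
+    # parsed tokens are appended to the innermost open list (the last one).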
+ cdef list stack = []
+ cdef list a = []
+ rv = a
+ cdef PyObject *tmp
+
+ cdef CLoader cloader = None
+ cdef object pyload = None
+ if (<RowLoader>row_loader).cloader is not None:
+ cloader = (<RowLoader>row_loader).cloader
+ else:
+ pyload = (<RowLoader>row_loader).loadfunc
+
+ while buf < end:
+ if buf[0] == b'{':
+ if stack:
+ tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1)
+ PyList_Append(<object>tmp, a)
+ PyList_Append(stack, a)
+ a = []
+ buf += 1
+
+ elif buf[0] == b'}':
+ if not stack:
+ raise e.DataError("malformed array: unexpected '}'")
+ rv = stack.pop()
+ buf += 1
+
+ elif buf[0] == cdelim:
+ buf += 1
+
+ else:
+ v = _parse_token(
+ &buf, end, cdelim, scratch, sclen, cloader, pyload)
+ if not stack:
+ raise e.DataError("malformed array: missing initial '{'")
+ tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1)
+ PyList_Append(<object>tmp, v)
+
+ return rv
+
+
+cdef object _parse_token(
+ const char **bufptr, const char *bufend, char cdelim,
+ char **scratch, size_t *sclen, CLoader cloader, object load
+):
+ cdef const char *start = bufptr[0]
+ cdef int has_quotes = start[0] == b'"'
+ cdef int quoted = has_quotes
+ cdef int num_escapes = 0
+ cdef int escaped = 0
+
+ if has_quotes:
+ start += 1
+ cdef const char *end = start
+
+ while end < bufend:
+ if (end[0] == cdelim or end[0] == b'}') and not quoted:
+ break
+ elif end[0] == b'\\' and not escaped:
+ num_escapes += 1
+ escaped = 1
+ end += 1
+ continue
+ elif end[0] == b'"' and not escaped:
+            quoted = 0
+        escaped = 0
+        end += 1
+    else:
+        raise e.DataError("malformed array: hit the end of the buffer")
+
+ # Return the new position for the buffer
+ bufptr[0] = end
+ if has_quotes:
+ end -= 1
+
+ cdef int length = (end - start)
+ if length == 4 and not has_quotes \
+ and start[0] == b'N' and start[1] == b'U' \
+ and start[2] == b'L' and start[3] == b'L':
+ return None
+
+ cdef const char *src
+ cdef char *tgt
+ cdef size_t unesclen
+
+ if not num_escapes:
+ if cloader is not None:
+ return cloader.cload(start, length)
+ else:
+ b = start[:length]
+ return load(b)
+
+ else:
+ unesclen = length - num_escapes + 1
+ if unesclen > sclen[0]:
+ scratch[0] = <char *>PyMem_Realloc(scratch[0], unesclen)
+ sclen[0] = unesclen
+
+ src = start
+ tgt = scratch[0]
+ while src < end:
+ if src[0] == b'\\':
+ src += 1
+ tgt[0] = src[0]
+ src += 1
+ tgt += 1
+
+ tgt[0] = b'\x00'
+
+ if cloader is not None:
+ return cloader.cload(scratch[0], length - num_escapes)
+ else:
+ b = scratch[0][:length - num_escapes]
+ return load(b)
+
+
+@cython.cdivision(True)
+cdef object _array_load_binary(
+ const char *buf, size_t length, Transformer tx, PyObject **row_loader_ptr
+):
+ # head is ndims, hasnull, elem oid
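+    # e.g. a one-dimensional int4 array starts with the words: ndims=1,
+    # has-null flag, element oid (23), dim 0 size, dim 0 lower bound, followed
+    # by each element as a 4-byte length (-1 for NULL) and its data.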
+ cdef uint32_t *buf32 = <uint32_t *>buf
+ cdef int ndims = endian.be32toh(buf32[0])
+
+ if ndims <= 0:
+ return []
+ elif ndims > MAXDIM:
+ raise e.DataError(
+ r"unexpected number of dimensions %s exceeding the maximum allowed %s"
+ % (ndims, MAXDIM)
+ )
+
+ cdef object oid
+ if row_loader_ptr[0] == NULL:
+ oid = <Oid>endian.be32toh(buf32[2])
+ row_loader_ptr[0] = tx._c_get_loader(<PyObject *>oid, <PyObject *>PQ_BINARY)
+
+ cdef Py_ssize_t[MAXDIM] dims
+ cdef int i
+ for i in range(ndims):
+        # Each dimension is encoded as (size, lower bound); only the size is used here.
+ dims[i] = endian.be32toh(buf32[3 + 2 * i])
+
+ buf += (3 + 2 * ndims) * sizeof(uint32_t)
+ out = _array_load_binary_rec(ndims, dims, &buf, row_loader_ptr[0])
+ return out
+
+
+cdef object _array_load_binary_rec(
+ Py_ssize_t ndims, Py_ssize_t *dims, const char **bufptr, PyObject *row_loader
+):
+ cdef const char *buf
+ cdef int i
+ cdef int32_t size
+ cdef object val
+
+ cdef Py_ssize_t nelems = dims[0]
+ cdef list out = PyList_New(nelems)
+
+ if ndims == 1:
+ buf = bufptr[0]
+ for i in range(nelems):
+ size = <int32_t>endian.be32toh((<uint32_t *>buf)[0])
+ buf += sizeof(uint32_t)
+ if size == -1:
+ val = None
+ else:
+ if (<RowLoader>row_loader).cloader is not None:
+ val = (<RowLoader>row_loader).cloader.cload(buf, size)
+ else:
+ val = (<RowLoader>row_loader).loadfunc(buf[:size])
+ buf += size
+
+ Py_INCREF(val)
+ PyList_SET_ITEM(out, i, val)
+
+ bufptr[0] = buf
+
+ else:
+ for i in range(nelems):
+ val = _array_load_binary_rec(ndims - 1, dims + 1, bufptr, row_loader)
+ Py_INCREF(val)
+ PyList_SET_ITEM(out, i, val)
+
+ return out
diff --git a/psycopg_c/psycopg_c/types/bool.pyx b/psycopg_c/psycopg_c/types/bool.pyx
new file mode 100644
index 0000000..86cf88e
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/bool.pyx
@@ -0,0 +1,78 @@
+"""
+Cython adapters for boolean.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+
+@cython.final
+cdef class BoolDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.BOOL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *buf = CDumper.ensure_size(rv, offset, 1)
+
+ # Fast paths, just a pointer comparison
+ if obj is True:
+ buf[0] = b"t"
+ elif obj is False:
+ buf[0] = b"f"
+ elif obj:
+ buf[0] = b"t"
+ else:
+ buf[0] = b"f"
+
+ return 1
+
+ def quote(self, obj: bool) -> bytes:
+ if obj is True:
+ return b"true"
+ elif obj is False:
+ return b"false"
+ else:
+ return b"true" if obj else b"false"
+
+
+@cython.final
+cdef class BoolBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.BOOL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *buf = CDumper.ensure_size(rv, offset, 1)
+
+ # Fast paths, just a pointer comparison
+ if obj is True:
+ buf[0] = b"\x01"
+ elif obj is False:
+ buf[0] = b"\x00"
+ elif obj:
+ buf[0] = b"\x01"
+ else:
+ buf[0] = b"\x00"
+
+ return 1
+
+
+@cython.final
+cdef class BoolLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ # this creates better C than `return data[0] == b't'`
+ return True if data[0] == b't' else False
+
+
+@cython.final
+cdef class BoolBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return True if data[0] else False
diff --git a/psycopg_c/psycopg_c/types/datetime.pyx b/psycopg_c/psycopg_c/types/datetime.pyx
new file mode 100644
index 0000000..51e7dcf
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/datetime.pyx
@@ -0,0 +1,1136 @@
+"""
+Cython adapters for date/time types.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from libc.string cimport memset, strchr
+from cpython cimport datetime as cdt
+from cpython.dict cimport PyDict_GetItem
+from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
+
+cdef extern from "Python.h":
+ const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL
+ object PyTimeZone_FromOffset(object offset)
+
+cdef extern from *:
+ """
+/* Multipliers from fraction of seconds to microseconds */
+static int _uspad[] = {0, 100000, 10000, 1000, 100, 10, 1};
+ """
+ cdef int *_uspad
+
+from datetime import date, time, timedelta, datetime, timezone
+
+from psycopg_c._psycopg cimport endian
+
+from psycopg import errors as e
+from psycopg._compat import ZoneInfo
+
+
+# Initialise the datetime C API
+cdt.import_datetime()
+
+cdef enum:
+ ORDER_YMD = 0
+ ORDER_DMY = 1
+ ORDER_MDY = 2
+ ORDER_PGDM = 3
+ ORDER_PGMD = 4
+
+cdef enum:
+ INTERVALSTYLE_OTHERS = 0
+ INTERVALSTYLE_SQL_STANDARD = 1
+ INTERVALSTYLE_POSTGRES = 2
+
+cdef enum:
+ PG_DATE_EPOCH_DAYS = 730120 # date(2000, 1, 1).toordinal()
+ PY_DATE_MIN_DAYS = 1 # date.min.toordinal()
+
+cdef object date_toordinal = date.toordinal
+cdef object date_fromordinal = date.fromordinal
+cdef object datetime_astimezone = datetime.astimezone
+cdef object time_utcoffset = time.utcoffset
+cdef object timedelta_total_seconds = timedelta.total_seconds
+cdef object timezone_utc = timezone.utc
+cdef object pg_datetime_epoch = datetime(2000, 1, 1)
+cdef object pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
+
+cdef object _month_abbr = {
+ n: i
+ for i, n in enumerate(
+ b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1
+ )
+}
+
+
+@cython.final
+cdef class DateDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.DATE_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class DateBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.DATE_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int32_t days = PyObject_CallFunctionObjArgs(
+ date_toordinal, <PyObject *>obj, NULL)
+ days -= PG_DATE_EPOCH_DAYS
+ cdef int32_t *buf = <int32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int32_t))
+ buf[0] = endian.htobe32(days)
+ return sizeof(int32_t)
+
+
+cdef class _BaseTimeDumper(CDumper):
+
+ cpdef get_key(self, obj, format):
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ cpdef upgrade(self, obj: time, format):
+ raise NotImplementedError
+
+
+cdef class _BaseTimeTextDumper(_BaseTimeDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class TimeDumper(_BaseTimeTextDumper):
+
+ oid = oids.TIME_OID
+
+ cpdef upgrade(self, obj, format):
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+@cython.final
+cdef class TimeTzDumper(_BaseTimeTextDumper):
+
+ oid = oids.TIMETZ_OID
+
+
+@cython.final
+cdef class TimeBinaryDumper(_BaseTimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIME_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = cdt.time_microsecond(obj) + 1000000 * (
+ cdt.time_second(obj)
+ + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj))
+ )
+
+ cdef int64_t *buf = <int64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int64_t))
+ buf[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+ cpdef upgrade(self, obj, format):
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzBinaryDumper(self.cls)
+
+
+@cython.final
+cdef class TimeTzBinaryDumper(_BaseTimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMETZ_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = cdt.time_microsecond(obj) + 1_000_000 * (
+ cdt.time_second(obj)
+ + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj))
+ )
+
+ off = PyObject_CallFunctionObjArgs(time_utcoffset, <PyObject *>obj, NULL)
+ cdef int32_t offsec = int(PyObject_CallFunctionObjArgs(
+ timedelta_total_seconds, <PyObject *>off, NULL))
+
+ cdef char *buf = CDumper.ensure_size(
+ rv, offset, sizeof(int64_t) + sizeof(int32_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
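+        # PostgreSQL stores the zone as seconds west of UTC, the opposite of
+        # Python's utcoffset(), hence the sign flip.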
+ (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(-offsec)
+
+ return sizeof(int64_t) + sizeof(int32_t)
+
+
+cdef class _BaseDatetimeDumper(CDumper):
+
+ cpdef get_key(self, obj, format):
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ cpdef upgrade(self, obj: time, format):
+ raise NotImplementedError
+
+
+cdef class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
+ # the YYYY-MM-DD is always understood correctly.
+ cdef str s = str(obj)
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class DatetimeDumper(_BaseDatetimeTextDumper):
+
+ oid = oids.TIMESTAMPTZ_OID
+
+ cpdef upgrade(self, obj, format):
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzDumper(self.cls)
+
+
+@cython.final
+cdef class DatetimeNoTzDumper(_BaseDatetimeTextDumper):
+
+ oid = oids.TIMESTAMP_OID
+
+
+@cython.final
+cdef class DatetimeBinaryDumper(_BaseDatetimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMESTAMPTZ_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
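+        # Binary timestamptz is the number of microseconds since the PostgreSQL
+        # epoch, 2000-01-01 00:00:00 UTC.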
+ delta = obj - pg_datetimetz_epoch
+
+ cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * (
+ 86_400 * <int64_t>cdt.timedelta_days(delta)
+ + <int64_t>cdt.timedelta_seconds(delta))
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+ cpdef upgrade(self, obj, format):
+ if obj.tzinfo:
+ return self
+ else:
+ return DatetimeNoTzBinaryDumper(self.cls)
+
+
+@cython.final
+cdef class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
+
+ format = PQ_BINARY
+ oid = oids.TIMESTAMP_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ delta = obj - pg_datetime_epoch
+
+ cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * (
+ 86_400 * <int64_t>cdt.timedelta_days(delta)
+ + <int64_t>cdt.timedelta_seconds(delta))
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ return sizeof(int64_t)
+
+
+@cython.final
+cdef class TimedeltaDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.INTERVAL_OID
+ cdef int _style
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_intervalstyle(self._pgconn)
+ if ds[0] == b's': # sql_standard
+ self._style = INTERVALSTYLE_SQL_STANDARD
+ else: # iso_8601, postgres, postgres_verbose
+ self._style = INTERVALSTYLE_OTHERS
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ cdef str s
+ if self._style == INTERVALSTYLE_OTHERS:
+ # The comma is parsed ok by PostgreSQL but it's not documented
+ # and it seems brittle to rely on it. CRDB doesn't consume it well.
+ s = str(obj).replace(",", "")
+ else:
+ # sql_standard format needs explicit signs
+ # otherwise -1 day 1 sec will mean -1 sec
+ s = "%+d day %+d second %+d microsecond" % (
+ obj.days, obj.seconds, obj.microseconds)
+
+ src = PyUnicode_AsUTF8AndSize(s, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class TimedeltaBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INTERVAL_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t micros = (
+ 1_000_000 * <int64_t>cdt.timedelta_seconds(obj)
+ + cdt.timedelta_microseconds(obj))
+ cdef int32_t days = cdt.timedelta_days(obj)
+
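+        # Binary interval layout: microseconds (int64), days (int32), months
+        # (int32). Python's timedelta has no months, so that field is 0.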
+ cdef char *buf = CDumper.ensure_size(
+ rv, offset, sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t))
+ (<int64_t *>buf)[0] = endian.htobe64(micros)
+ (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(days)
+ (<int32_t *>(buf + sizeof(int64_t) + sizeof(int32_t)))[0] = 0
+
+ return sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t)
+
+
+@cython.final
+cdef class DateLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ elif ds[0] == b'G': # German
+ self._order = ORDER_DMY
+ elif ds[0] == b'S': # SQL, DMY / MDY
+ self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY
+ elif ds[0] == b'P': # Postgres, DMY / MDY
+ self._order = ORDER_DMY if ds[10] == b'D' else ORDER_MDY
+ else:
+ raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ cdef object _error_date(self, const char *data, str msg):
+ s = bytes(data).decode("utf8", "replace")
+ if s == "infinity" or len(s.split()[0]) > 10:
+ raise e.DataError(f"date too large (after year 10K): {s!r}") from None
+ elif s == "-infinity" or "BC" in s:
+ raise e.DataError(f"date too small (before year 1): {s!r}") from None
+ else:
+ raise e.DataError(f"can't parse date {s!r}: {msg}") from None
+
+ cdef object cload(self, const char *data, size_t length):
+ if length != 10:
+ self._error_date(data, "unexpected length")
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+
+ cdef const char *ptr
+ cdef const char *end = data + length
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse date {s!r}")
+
+ try:
+ if self._order == ORDER_YMD:
+ return cdt.date_new(vals[0], vals[1], vals[2])
+ elif self._order == ORDER_DMY:
+ return cdt.date_new(vals[2], vals[1], vals[0])
+ else:
+ return cdt.date_new(vals[2], vals[0], vals[1])
+ except ValueError as ex:
+ self._error_date(data, str(ex))
+
+
+@cython.final
+cdef class DateBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
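+        # Binary date is the number of days since 2000-01-01: shift it to a
+        # Python ordinal (days since year 1) and build the date from that.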
+ cdef int days = endian.be32toh((<uint32_t *>data)[0])
+ cdef object pydays = days + PG_DATE_EPOCH_DAYS
+ try:
+ return PyObject_CallFunctionObjArgs(
+ date_fromordinal, <PyObject *>pydays, NULL)
+ except ValueError:
+ if days < PY_DATE_MIN_DAYS:
+ raise e.DataError("date too small (before year 1)") from None
+ else:
+ raise e.DataError("date too large (after year 10K)") from None
+
+
+@cython.final
+cdef class TimeLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+ cdef const char *end = data + length
+
+ # Parse the first 3 groups of digits
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse time {s!r}")
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ try:
+ return cdt.time_new(vals[0], vals[1], vals[2], us, None)
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse time {s!r}: {ex}") from None
+
+
+@cython.final
+cdef class TimeBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int h, m, s, us
+
+ with cython.cdivision(True):
+ us = val % 1_000_000
+ val //= 1_000_000
+
+ s = val % 60
+ val //= 60
+
+ m = val % 60
+ h = <int>(val // 60)
+
+ try:
+ return cdt.time_new(h, m, s, us, None)
+ except ValueError:
+ raise e.DataError(
+ f"time not supported by Python: hour={h}"
+ ) from None
+
+
+@cython.final
+cdef class TimetzLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+ cdef const char *end = data + length
+
+ # Parse the first 3 groups of digits (time)
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}")
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Parse the timezone
+ cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}")
+
+ tz = _timezone_from_seconds(offsecs)
+ try:
+ return cdt.time_new(vals[0], vals[1], vals[2], us, tz)
+ except ValueError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse timetz {s!r}: {ex}") from None
+
+
+@cython.final
+cdef class TimetzBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int32_t off = endian.be32toh((<uint32_t *>(data + sizeof(int64_t)))[0])
+ cdef int h, m, s, us
+
+ with cython.cdivision(True):
+ us = val % 1_000_000
+ val //= 1_000_000
+
+ s = val % 60
+ val //= 60
+
+ m = val % 60
+ h = <int>(val // 60)
+
+ tz = _timezone_from_seconds(-off)
+ try:
+ return cdt.time_new(h, m, s, us, tz)
+ except ValueError:
+ raise e.DataError(
+ f"time not supported by Python: hour={h}"
+ ) from None
+
+
+@cython.final
+cdef class TimestampLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ elif ds[0] == b'G': # German
+ self._order = ORDER_DMY
+ elif ds[0] == b'S': # SQL, DMY / MDY
+ self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY
+ elif ds[0] == b'P': # Postgres, DMY / MDY
+ self._order = ORDER_PGDM if ds[10] == b'D' else ORDER_PGMD
+ else:
+ raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef const char *end = data + length
+ if end[-1] == b'C': # ends with BC
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
+ return self._cload_pg(data, end)
+
+ cdef int vals[6]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+
+ # Parse the first 6 groups of digits (date and time)
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Resolve the YMD order
+ cdef int y, m, d
+ if self._order == ORDER_YMD:
+ y, m, d = vals[0], vals[1], vals[2]
+ elif self._order == ORDER_DMY:
+ d, m, y = vals[0], vals[1], vals[2]
+ else: # self._order == ORDER_MDY
+ m, d, y = vals[0], vals[1], vals[2]
+
+ try:
+ return cdt.datetime_new(
+ y, m, d, vals[3], vals[4], vals[5], us, None)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+ cdef object _cload_pg(self, const char *data, const char *end):
+ cdef int vals[4]
+ memset(vals, 0, sizeof(vals))
+ cdef const char *ptr
+
+ # Find Wed Jun 02 or Wed 02 Jun
+ cdef char *seps[3]
+ seps[0] = strchr(data, b' ')
+ seps[1] = strchr(seps[0] + 1, b' ') if seps[0] != NULL else NULL
+ seps[2] = strchr(seps[1] + 1, b' ') if seps[1] != NULL else NULL
+ if seps[2] == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the following 3 groups of digits (time)
+ ptr = _parse_date_values(seps[2] + 1, end, vals, 3)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Parse the year
+ ptr = _parse_date_values(ptr + 1, end, vals + 3, 1)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Resolve the MD order
+ cdef int m, d
+ try:
+ if self._order == ORDER_PGDM:
+ d = int(seps[0][1 : seps[1] - seps[0]])
+ m = _month_abbr[seps[1][1 : seps[2] - seps[1]]]
+ else: # self._order == ORDER_PGMD
+ m = _month_abbr[seps[0][1 : seps[1] - seps[0]]]
+ d = int(seps[1][1 : seps[2] - seps[1]])
+ except (KeyError, ValueError) as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+ try:
+ return cdt.datetime_new(
+ vals[3], m, d, vals[0], vals[1], vals[2], us, None)
+ except ValueError as ex:
+ raise _get_timestamp_load_error(self._pgconn, data, ex) from None
+
+
+@cython.final
+cdef class TimestampBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int64_t micros, secs, days
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+
+        # Group the micros into bigger units or timedelta_new might overflow
+ with cython.cdivision(True):
+ secs = aval // 1_000_000
+ micros = aval % 1_000_000
+
+ days = secs // 86_400
+ secs %= 86_400
+
+ try:
+ delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros)
+ if val > 0:
+ return pg_datetime_epoch + delta
+ else:
+ return pg_datetime_epoch - delta
+
+ except OverflowError:
+ if val <= 0:
+ raise e.DataError("timestamp too small (before year 1)") from None
+ else:
+ raise e.DataError("timestamp too large (after year 10K)") from None
+
+
+cdef class _BaseTimestamptzLoader(CLoader):
+ cdef object _time_zone
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+ self._time_zone = _timezone_from_connection(self._pgconn)
+
+
+@cython.final
+cdef class TimestamptzLoader(_BaseTimestamptzLoader):
+
+ format = PQ_TEXT
+ cdef int _order
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_datestyle(self._pgconn)
+ if ds[0] == b'I': # ISO
+ self._order = ORDER_YMD
+ else: # Not true, but any non-YMD will do.
+ self._order = ORDER_DMY
+
+ cdef object cload(self, const char *data, size_t length):
+ if self._order != ORDER_YMD:
+ return self._cload_notimpl(data, length)
+
+ cdef const char *end = data + length
+ if end[-1] == b'C': # ends with BC
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ cdef int vals[6]
+ memset(vals, 0, sizeof(vals))
+
+ # Parse the first 6 groups of digits (date and time)
+ cdef const char *ptr
+ ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ # Parse the microseconds
+ cdef int us = 0
+ if ptr[0] == b".":
+ ptr = _parse_micros(ptr + 1, &us)
+
+ # Resolve the YMD order
+ cdef int y, m, d
+ if self._order == ORDER_YMD:
+ y, m, d = vals[0], vals[1], vals[2]
+ elif self._order == ORDER_DMY:
+ d, m, y = vals[0], vals[1], vals[2]
+ else: # self._order == ORDER_MDY
+ m, d, y = vals[0], vals[1], vals[2]
+
+ # Parse the timezone
+ cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
+ if ptr == NULL:
+ raise _get_timestamp_load_error(self._pgconn, data) from None
+
+ tzoff = cdt.timedelta_new(0, offsecs, 0)
+
+ # The return value is a datetime with the timezone of the connection
+        # (in order to be consistent with the binary loader, which can only
+        # return that timezone). So create a temporary datetime object in UTC,
+ # shift it by the offset parsed from the timestamp, and then move it to
+ # the connection timezone.
+ dt = None
+ try:
+ dt = cdt.datetime_new(
+ y, m, d, vals[3], vals[4], vals[5], us, timezone_utc)
+ dt -= tzoff
+ return PyObject_CallFunctionObjArgs(datetime_astimezone,
+ <PyObject *>dt, <PyObject *>self._time_zone, NULL)
+ except OverflowError as ex:
+ # If we have created the temporary 'dt' it means that we have a
+ # datetime close to max, the shift pushed it past max, overflowing.
+ # In this case return the datetime in a fixed offset timezone.
+ if dt is not None:
+ return dt.replace(tzinfo=timezone(tzoff))
+ else:
+ ex1 = ex
+ except ValueError as ex:
+ ex1 = ex
+
+ raise _get_timestamp_load_error(self._pgconn, data, ex1) from None
+
+ cdef object _cload_notimpl(self, const char *data, size_t length):
+ s = bytes(data)[:length].decode("utf8", "replace")
+ ds = _get_datestyle(self._pgconn).decode()
+ raise NotImplementedError(
+ f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
+ )
+
+
+@cython.final
+cdef class TimestamptzBinaryLoader(_BaseTimestamptzLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int64_t micros, secs, days
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+
+        # Group the micros into bigger units or timedelta_new might overflow
+ with cython.cdivision(True):
+ secs = aval // 1_000_000
+ micros = aval % 1_000_000
+
+ days = secs // 86_400
+ secs %= 86_400
+
+ try:
+ delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros)
+ if val > 0:
+ dt = pg_datetimetz_epoch + delta
+ else:
+ dt = pg_datetimetz_epoch - delta
+ return PyObject_CallFunctionObjArgs(datetime_astimezone,
+ <PyObject *>dt, <PyObject *>self._time_zone, NULL)
+
+ except OverflowError:
+ # If we were asked about a timestamp which would overflow in UTC,
+ # but not in the desired timezone (e.g. datetime.max at Chicago
+ # timezone) we can still save the day by shifting the value by the
+ # timezone offset and then replacing the timezone.
+ if self._time_zone is not None:
+ utcoff = self._time_zone.utcoffset(
+ datetime.min if val < 0 else datetime.max
+ )
+ if utcoff:
+ usoff = 1_000_000 * int(utcoff.total_seconds())
+ try:
+ ts = pg_datetime_epoch + timedelta(
+ microseconds=val + usoff
+ )
+ except OverflowError:
+ pass # will raise downstream
+ else:
+ return ts.replace(tzinfo=self._time_zone)
+
+ if val <= 0:
+ raise e.DataError(
+ "timestamp too small (before year 1)"
+ ) from None
+ else:
+ raise e.DataError(
+ "timestamp too large (after year 10K)"
+ ) from None
+
+
+@cython.final
+cdef class IntervalLoader(CLoader):
+
+ format = PQ_TEXT
+ cdef int _style
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ cdef const char *ds = _get_intervalstyle(self._pgconn)
+ if ds[0] == b'p' and ds[8] == 0: # postgres
+ self._style = INTERVALSTYLE_POSTGRES
+ else: # iso_8601, sql_standard, postgres_verbose
+ self._style = INTERVALSTYLE_OTHERS
+
+ cdef object cload(self, const char *data, size_t length):
+ if self._style == INTERVALSTYLE_OTHERS:
+ return self._cload_notimpl(data, length)
+
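+        # A 'postgres' style interval looks like '1 year 2 mons 3 days
+        # 04:05:06.789': parse the optional '<n> <unit>' words first, then the
+        # time part.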
+ cdef int days = 0, secs = 0, us = 0
+ cdef char sign
+ cdef int val
+ cdef const char *ptr = data
+ cdef const char *sep
+ cdef const char *end = ptr + length
+
+ # If there are spaces, there is a [+|-]n [days|months|years]
+ while True:
+ if ptr[0] == b'-' or ptr[0] == b'+':
+ sign = ptr[0]
+ ptr += 1
+ else:
+ sign = 0
+
+ sep = strchr(ptr, b' ')
+ if sep == NULL or sep > end:
+ break
+
+ val = 0
+ ptr = _parse_date_values(ptr, end, &val, 1)
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ if sign == b'-':
+ val = -val
+
+            if ptr[1] == b'y':
+                days += 365 * val
+            elif ptr[1] == b'm':
+                days += 30 * val
+            elif ptr[1] == b'd':
+                days += val
+ else:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ # Skip the date part word.
+ ptr = strchr(ptr + 1, b' ')
+ if ptr != NULL and ptr < end:
+ ptr += 1
+ else:
+ break
+
+ # Parse the time part. An eventual sign was already consumed in the loop
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+ if ptr != NULL:
+ ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}")
+
+ secs = vals[2] + 60 * (vals[1] + 60 * vals[0])
+
+ if ptr[0] == b'.':
+ ptr = _parse_micros(ptr + 1, &us)
+
+ if sign == b'-':
+ secs = -secs
+ us = -us
+
+ try:
+ return cdt.timedelta_new(days, secs, us)
+ except OverflowError as ex:
+ s = bytes(data).decode("utf8", "replace")
+ raise e.DataError(f"can't parse interval {s!r}: {ex}") from None
+
+ cdef object _cload_notimpl(self, const char *data, size_t length):
+ s = bytes(data).decode("utf8", "replace")
+ style = _get_intervalstyle(self._pgconn).decode()
+ raise NotImplementedError(
+ f"can't parse interval with IntervalStyle {style!r}: {s!r}"
+ )
+
+
+@cython.final
+cdef class IntervalBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef int64_t val = endian.be64toh((<uint64_t *>data)[0])
+ cdef int32_t days = endian.be32toh(
+ (<uint32_t *>(data + sizeof(int64_t)))[0])
+ cdef int32_t months = endian.be32toh(
+ (<uint32_t *>(data + sizeof(int64_t) + sizeof(int32_t)))[0])
+
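+        # Python's timedelta has no months or years: approximate them as 30
+        # and 365 days respectively, consistently with the text loader above.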
+ cdef int years
+ with cython.cdivision(True):
+ if months > 0:
+ years = months // 12
+ months %= 12
+ days += 30 * months + 365 * years
+ elif months < 0:
+ months = -months
+ years = months // 12
+ months %= 12
+ days -= 30 * months + 365 * years
+
+ # Work only with positive values as the cdivision behaves differently
+ # with negative values, and cdivision=False adds overhead.
+ cdef int64_t aval = val if val >= 0 else -val
+ cdef int us, ussecs, usdays
+
+        # Group the micros into bigger units or timedelta_new might overflow
+ with cython.cdivision(True):
+ ussecs = <int>(aval // 1_000_000)
+ us = aval % 1_000_000
+
+ usdays = ussecs // 86_400
+ ussecs %= 86_400
+
+ if val < 0:
+ ussecs = -ussecs
+ usdays = -usdays
+ us = -us
+
+ try:
+ return cdt.timedelta_new(days + usdays, ussecs, us)
+ except OverflowError as ex:
+ raise e.DataError(f"can't parse interval: {ex}")
+
+
+cdef const char *_parse_date_values(
+ const char *ptr, const char *end, int *vals, int nvals
+):
+ """
+ Parse *nvals* numeric values separated by non-numeric chars.
+
+    Write the result in the *vals* array (assumed zeroed) starting from *ptr*.
+
+ Return the pointer at the separator after the final digit.
+ """
+ cdef int ival = 0
+ while ptr < end:
+ if b'0' <= ptr[0] <= b'9':
+ vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0')
+ else:
+ ival += 1
+ if ival >= nvals:
+ break
+
+ ptr += 1
+
+ return ptr
+
+
+cdef const char *_parse_micros(const char *start, int *us):
+ """
+ Parse microseconds from a string.
+
+    Microseconds are expected as up to 6 digits, terminated by a non-digit char.
+
+ Return the pointer at the separator after the final digit.
+ """
+ cdef const char *ptr = start
+ while ptr[0]:
+ if b'0' <= ptr[0] <= b'9':
+ us[0] = us[0] * 10 + (ptr[0] - <char>b'0')
+ else:
+ break
+
+ ptr += 1
+
+    # Pad the fraction of a second to get microseconds
+ if us[0] and ptr - start < 6:
+ us[0] *= _uspad[ptr - start]
+
+ return ptr
+
+
+cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end):
+ """
+    Parse a timezone offset from a string, return it as seconds east of UTC.
+
+ Modify the buffer pointer to point at the first character after the
+ timezone parsed. In case of parse error make it NULL.
+ """
+ cdef const char *ptr = bufptr[0]
+ cdef char sgn = ptr[0]
+
+ # Parse at most three groups of digits
+ cdef int vals[3]
+ memset(vals, 0, sizeof(vals))
+
+ ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals))
+ if ptr == NULL:
+ return 0
+
+ cdef int off = 60 * (60 * vals[0] + vals[1]) + vals[2]
+ return -off if sgn == b"-" else off
+
+
+cdef object _timezone_from_seconds(int sec, __cache={}):
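+    # The mutable default argument is used on purpose as a per-process cache
+    # of timezone objects keyed by their offset in seconds.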
+ cdef object pysec = sec
+ cdef PyObject *ptr = PyDict_GetItem(__cache, pysec)
+ if ptr != NULL:
+ return <object>ptr
+
+ delta = cdt.timedelta_new(0, sec, 0)
+ tz = timezone(delta)
+ __cache[pysec] = tz
+ return tz
+
+
+cdef object _get_timestamp_load_error(
+ pq.PGconn pgconn, const char *data, ex: Optional[Exception] = None
+):
+ s = bytes(data).decode("utf8", "replace")
+
+ def is_overflow(s):
+ if not s:
+ return False
+
+ ds = _get_datestyle(pgconn)
+ if not ds.startswith(b"P"): # Postgres
+ return len(s.split()[0]) > 10 # date is first token
+ else:
+ return len(s.split()[-1]) > 4 # year is last token
+
+ if s == "-infinity" or s.endswith("BC"):
+ return e.DataError("timestamp too small (before year 1): {s!r}")
+ elif s == "infinity" or is_overflow(s):
+ return e.DataError(f"timestamp too large (after year 10K): {s!r}")
+ else:
+ return e.DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}")
+
+
+cdef _timezones = {}
+_timezones[None] = timezone_utc
+_timezones[b"UTC"] = timezone_utc
+
+
+cdef object _timezone_from_connection(pq.PGconn pgconn):
+ """Return the Python timezone info of the connection's timezone."""
+ if pgconn is None:
+ return timezone_utc
+
+ cdef bytes tzname = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"TimeZone")
+ cdef PyObject *ptr = PyDict_GetItem(_timezones, tzname)
+ if ptr != NULL:
+ return <object>ptr
+
+ sname = tzname.decode() if tzname else "UTC"
+ try:
+ zi = ZoneInfo(sname)
+ except (KeyError, OSError):
+ logger.warning(
+ "unknown PostgreSQL timezone: %r; will use UTC", sname
+ )
+ zi = timezone_utc
+ except Exception as ex:
+ logger.warning(
+ "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
+ sname,
+ type(ex).__name__,
+ ex,
+ )
+        zi = timezone_utc
+
+ _timezones[tzname] = zi
+ return zi
+
+
+cdef const char *_get_datestyle(pq.PGconn pgconn):
+ cdef const char *ds
+ if pgconn is not None:
+ ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"DateStyle")
+ if ds is not NULL and ds[0]:
+ return ds
+
+ return b"ISO, DMY"
+
+
+cdef const char *_get_intervalstyle(pq.PGconn pgconn):
+ cdef const char *ds
+ if pgconn is not None:
+ ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"IntervalStyle")
+ if ds is not NULL and ds[0]:
+ return ds
+
+ return b"postgres"
diff --git a/psycopg_c/psycopg_c/types/numeric.pyx b/psycopg_c/psycopg_c/types/numeric.pyx
new file mode 100644
index 0000000..893bdc2
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/numeric.pyx
@@ -0,0 +1,715 @@
+"""
+Cython adapters for numeric types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+from libc.stdint cimport *
+from libc.string cimport memcpy, strlen
+from cpython.mem cimport PyMem_Free
+from cpython.dict cimport PyDict_GetItem, PyDict_SetItem
+from cpython.long cimport (
+ PyLong_FromString, PyLong_FromLong, PyLong_FromLongLong,
+ PyLong_FromUnsignedLong, PyLong_AsLongLong)
+from cpython.bytes cimport PyBytes_AsStringAndSize
+from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble
+from cpython.unicode cimport PyUnicode_DecodeUTF8
+
+from decimal import Decimal, Context, DefaultContext
+
+from psycopg_c._psycopg cimport endian
+from psycopg import errors as e
+from psycopg._wrappers import Int2, Int4, Int8, IntNumeric
+
+cdef extern from "Python.h":
+ # work around https://github.com/cython/cython/issues/3909
+ double PyOS_string_to_double(
+ const char *s, char **endptr, PyObject *overflow_exception) except? -1.0
+ char *PyOS_double_to_string(
+ double val, char format_code, int precision, int flags, int *ptype
+ ) except NULL
+ int Py_DTSF_ADD_DOT_0
+ long long PyLong_AsLongLongAndOverflow(object pylong, int *overflow) except? -1
+
+ # Missing in cpython/unicode.pxd
+ const char *PyUnicode_AsUTF8(object unicode) except NULL
+
+
+# defined in numutils.c
+cdef extern from *:
+ """
+int pg_lltoa(int64_t value, char *a);
+#define MAXINT8LEN 20
+ """
+ int pg_lltoa(int64_t value, char *a)
+ const int MAXINT8LEN
+
+
+cdef class _NumberDumper(CDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_int_to_text(obj, rv, offset)
+
+ def quote(self, obj) -> bytearray:
+ cdef Py_ssize_t length
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+ if obj >= 0:
+ length = self.cdump(obj, rv, 0)
+ else:
+ PyByteArray_Resize(rv, 23)
+ rv[0] = b' '
+ length = 1 + self.cdump(obj, rv, 1)
+
+ PyByteArray_Resize(rv, length)
+ return rv
+
+
+@cython.final
+cdef class Int2Dumper(_NumberDumper):
+
+ oid = oids.INT2_OID
+
+
+@cython.final
+cdef class Int4Dumper(_NumberDumper):
+
+ oid = oids.INT4_OID
+
+
+@cython.final
+cdef class Int8Dumper(_NumberDumper):
+
+ oid = oids.INT8_OID
+
+
+@cython.final
+cdef class IntNumericDumper(_NumberDumper):
+
+ oid = oids.NUMERIC_OID
+
+
+@cython.final
+cdef class Int2BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT2_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int16_t *buf = <int16_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int16_t))
+ cdef int16_t val = <int16_t>PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint16_t *ptvar = <uint16_t *>(&val)
+ buf[0] = endian.htobe16(ptvar[0])
+ return sizeof(int16_t)
+
+
+@cython.final
+cdef class Int4BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT4_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int32_t *buf = <int32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int32_t))
+ cdef int32_t val = <int32_t>PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint32_t *ptvar = <uint32_t *>(&val)
+ buf[0] = endian.htobe32(ptvar[0])
+ return sizeof(int32_t)
+
+
+@cython.final
+cdef class Int8BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.INT8_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef int64_t *buf = <int64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(int64_t))
+ cdef int64_t val = PyLong_AsLongLong(obj)
+ # swap bytes if needed
+ cdef uint64_t *ptvar = <uint64_t *>(&val)
+ buf[0] = endian.htobe64(ptvar[0])
+ return sizeof(int64_t)
+
+
+cdef extern from *:
+ """
+/* Ratio between number of bits required to store a number and number of pg
+ * decimal digits required (log(2) / log(10_000)).
+ */
+#define BIT_PER_PGDIGIT 0.07525749891599529
+
+/* decimal digits per Postgres "digit" */
+#define DEC_DIGITS 4
+
+#define NUMERIC_POS 0x0000
+#define NUMERIC_NEG 0x4000
+#define NUMERIC_NAN 0xC000
+#define NUMERIC_PINF 0xD000
+#define NUMERIC_NINF 0xF000
+"""
+ const double BIT_PER_PGDIGIT
+ const int DEC_DIGITS
+ const int NUMERIC_POS
+ const int NUMERIC_NEG
+ const int NUMERIC_NAN
+ const int NUMERIC_PINF
+ const int NUMERIC_NINF
+
+
+@cython.final
+cdef class IntNumericBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_int_to_numeric_binary(obj, rv, offset)
+
+
+cdef class IntDumper(_NumberDumper):
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ raise TypeError(
+ f"{type(self).__name__} is a dispatcher to other dumpers:"
+ " dump() is not supposed to be called"
+ )
+
+ cpdef get_key(self, obj, format):
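+        # Return the key of the smallest PostgreSQL integer type able to
+        # represent the value, falling back to numeric on 64-bit overflow.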
+ cdef long long val
+ cdef int overflow
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if overflow:
+ return IntNumeric
+
+ if INT32_MIN <= obj <= INT32_MAX:
+ if INT16_MIN <= obj <= INT16_MAX:
+ return Int2
+ else:
+ return Int4
+ else:
+ if INT64_MIN <= obj <= INT64_MAX:
+ return Int8
+ else:
+ return IntNumeric
+
+ _int2_dumper = Int2Dumper
+ _int4_dumper = Int4Dumper
+ _int8_dumper = Int8Dumper
+ _int_numeric_dumper = IntNumericDumper
+
+ cpdef upgrade(self, obj, format):
+ cdef long long val
+ cdef int overflow
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if overflow:
+ return self._int_numeric_dumper(IntNumeric)
+
+ if INT32_MIN <= obj <= INT32_MAX:
+ if INT16_MIN <= obj <= INT16_MAX:
+ return self._int2_dumper(Int2)
+ else:
+ return self._int4_dumper(Int4)
+ else:
+ if INT64_MIN <= obj <= INT64_MAX:
+ return self._int8_dumper(Int8)
+ else:
+ return self._int_numeric_dumper(IntNumeric)
+
+
+@cython.final
+cdef class IntBinaryDumper(IntDumper):
+
+ format = PQ_BINARY
+
+ _int2_dumper = Int2BinaryDumper
+ _int4_dumper = Int4BinaryDumper
+ _int8_dumper = Int8BinaryDumper
+ _int_numeric_dumper = IntNumericBinaryDumper
+
+
+@cython.final
+cdef class IntLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+        # If the buffer is null-terminated right after the number we can parse it in place
+ if data[length] == b'\0':
+ return PyLong_FromString(data, NULL, 10)
+
+ # Otherwise we have to copy it aside
+ if length > MAXINT8LEN:
+ raise ValueError("string too big for an int")
+
+ cdef char[MAXINT8LEN + 1] buf
+ memcpy(buf, data, length)
+ buf[length] = 0
+ return PyLong_FromString(buf, NULL, 10)
+
+
+
+@cython.final
+cdef class Int2BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLong(<int16_t>endian.be16toh((<uint16_t *>data)[0]))
+
+
+@cython.final
+cdef class Int4BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLong(<int32_t>endian.be32toh((<uint32_t *>data)[0]))
+
+
+@cython.final
+cdef class Int8BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromLongLong(<int64_t>endian.be64toh((<uint64_t *>data)[0]))
+
+
+@cython.final
+cdef class OidBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return PyLong_FromUnsignedLong(endian.be32toh((<uint32_t *>data)[0]))
+
+
+cdef class _FloatDumper(CDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef double d = PyFloat_AsDouble(obj)
+ cdef char *out = PyOS_double_to_string(
+ d, b'r', 0, Py_DTSF_ADD_DOT_0, NULL)
+ cdef Py_ssize_t length = strlen(out)
+ cdef char *tgt = CDumper.ensure_size(rv, offset, length)
+ memcpy(tgt, out, length)
+ PyMem_Free(out)
+ return length
+
+ def quote(self, obj) -> bytes:
+ value = bytes(self.dump(obj))
+ cdef PyObject *ptr = PyDict_GetItem(_special_float, value)
+ if ptr != NULL:
+ return <object>ptr
+
+ return value if obj >= 0 else b" " + value
+
+cdef dict _special_float = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+}
+
+
+@cython.final
+cdef class FloatDumper(_FloatDumper):
+
+ oid = oids.FLOAT8_OID
+
+
+@cython.final
+cdef class Float4Dumper(_FloatDumper):
+
+ oid = oids.FLOAT4_OID
+
+
+@cython.final
+cdef class FloatBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.FLOAT8_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef double d = PyFloat_AsDouble(obj)
+ cdef uint64_t *intptr = <uint64_t *>&d
+ cdef uint64_t *buf = <uint64_t *>CDumper.ensure_size(
+ rv, offset, sizeof(uint64_t))
+ buf[0] = endian.htobe64(intptr[0])
+ return sizeof(uint64_t)
+
+
+@cython.final
+cdef class Float4BinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.FLOAT4_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef float f = <float>PyFloat_AsDouble(obj)
+ cdef uint32_t *intptr = <uint32_t *>&f
+ cdef uint32_t *buf = <uint32_t *>CDumper.ensure_size(
+ rv, offset, sizeof(uint32_t))
+ buf[0] = endian.htobe32(intptr[0])
+ return sizeof(uint32_t)
+
+
+@cython.final
+cdef class FloatLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
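+        # PyOS_string_to_double also understands 'Infinity' and 'NaN', which
+        # is how PostgreSQL spells the special float values.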
+ cdef char *endptr
+ cdef double d = PyOS_string_to_double(
+ data, &endptr, <PyObject *>OverflowError)
+ return PyFloat_FromDouble(d)
+
+
+@cython.final
+cdef class Float4BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef uint32_t asint = endian.be32toh((<uint32_t *>data)[0])
+ # avoid warning:
+ # dereferencing type-punned pointer will break strict-aliasing rules
+ cdef char *swp = <char *>&asint
+ return PyFloat_FromDouble((<float *>swp)[0])
+
+
+@cython.final
+cdef class Float8BinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef uint64_t asint = endian.be64toh((<uint64_t *>data)[0])
+ cdef char *swp = <char *>&asint
+ return PyFloat_FromDouble((<double *>swp)[0])
+
+
+@cython.final
+cdef class DecimalDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_decimal_to_text(obj, rv, offset)
+
+ def quote(self, obj) -> bytes:
+ value = bytes(self.dump(obj))
+ cdef PyObject *ptr = PyDict_GetItem(_special_decimal, value)
+ if ptr != NULL:
+ return <object>ptr
+
+ return value if obj >= 0 else b" " + value
+
+cdef dict _special_decimal = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+}
+
+
+@cython.final
+cdef class NumericLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ s = PyUnicode_DecodeUTF8(<char *>data, length, NULL)
+ return Decimal(s)
+
+
+cdef dict _decimal_special = {
+ NUMERIC_NAN: Decimal("NaN"),
+ NUMERIC_PINF: Decimal("Infinity"),
+ NUMERIC_NINF: Decimal("-Infinity"),
+}
+
+cdef dict _contexts = {}
+for _i in range(DefaultContext.prec):
+ _contexts[_i] = DefaultContext
+
+
+@cython.final
+cdef class NumericBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+
+ cdef uint16_t *data16 = <uint16_t *>data
+ cdef uint16_t ndigits = endian.be16toh(data16[0])
+ cdef int16_t weight = <int16_t>endian.be16toh(data16[1])
+ cdef uint16_t sign = endian.be16toh(data16[2])
+ cdef uint16_t dscale = endian.be16toh(data16[3])
+ cdef int shift
+ cdef int i
+ cdef PyObject *pctx
+ cdef object key
+
+ if sign == NUMERIC_POS or sign == NUMERIC_NEG:
+ if length != (4 + ndigits) * sizeof(uint16_t):
+ raise e.DataError("bad ndigits in numeric binary representation")
+
+ val = 0
+ for i in range(ndigits):
+ val *= 10_000
+ val += endian.be16toh(data16[i + 4])
+
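+            # The digits are base-10000: 'val' is the coefficient and the
+            # encoded value is val * 10000 ** (weight - ndigits + 1). Rebuild
+            # the Decimal by applying the declared scale and shifting the
+            # digits into place, in a context wide enough to hold them all.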
+ shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
+
+ key = (weight + 2) * DEC_DIGITS + dscale
+ pctx = PyDict_GetItem(_contexts, key)
+ if pctx == NULL:
+ ctx = Context(prec=key)
+ PyDict_SetItem(_contexts, key, ctx)
+ pctx = <PyObject *>ctx
+
+ return (
+ Decimal(val if sign == NUMERIC_POS else -val)
+ .scaleb(-dscale, <object>pctx)
+ .shift(shift, <object>pctx)
+ )
+ else:
+ try:
+ return _decimal_special[sign]
+ except KeyError:
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}")
+
+
+@cython.final
+cdef class DecimalBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ return dump_decimal_to_numeric_binary(obj, rv, offset)
+
+
+@cython.final
+cdef class NumericDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ if isinstance(obj, int):
+ return dump_int_to_text(obj, rv, offset)
+ else:
+ return dump_decimal_to_text(obj, rv, offset)
+
+
+@cython.final
+cdef class NumericBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.NUMERIC_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ if isinstance(obj, int):
+ return dump_int_to_numeric_binary(obj, rv, offset)
+ else:
+ return dump_decimal_to_numeric_binary(obj, rv, offset)
+
+
+cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *src
+ cdef Py_ssize_t length
+ cdef char *buf
+
+ b = bytes(str(obj), "utf-8")
+ PyBytes_AsStringAndSize(b, &src, &length)
+
+ if src[0] != b's':
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, src, length)
+
+ else: # convert sNaN to NaN
+ length = 3 # NaN
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, b"NaN", length)
+
+ return length
+
+
+cdef extern from *:
+ """
+/* Weights of py digits into a pg digit according to their positions. */
+static const int pydigit_weights[] = {1000, 100, 10, 1};
+"""
+ const int[4] pydigit_weights
+
+
+@cython.cdivision(True)
+cdef Py_ssize_t dump_decimal_to_numeric_binary(
+ obj, bytearray rv, Py_ssize_t offset
+) except -1:
+
+ # TODO: this implementation is about 30% slower than the text dump.
+    # It could probably be optimised by accessing the C structure of
+ # the Decimal object, if available, which would save the creation of
+ # several intermediate Python objects (the DecimalTuple, the digits
+ # tuple, and then accessing them).
+
+ cdef object t = obj.as_tuple()
+ cdef int sign = t[0]
+ cdef tuple digits = t[1]
+ cdef uint16_t *buf
+ cdef Py_ssize_t length
+
+ cdef object pyexp = t[2]
+ cdef const char *bexp
+ if not isinstance(pyexp, int):
+ # Handle inf, nan
+ length = 4 * sizeof(uint16_t)
+ buf = <uint16_t *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = 0
+ buf[1] = 0
+ buf[3] = 0
+ bexp = PyUnicode_AsUTF8(pyexp)
+ if bexp[0] == b'n' or bexp[0] == b'N':
+ buf[2] = endian.htobe16(NUMERIC_NAN)
+ elif bexp[0] == b'F':
+ if sign:
+ buf[2] = endian.htobe16(NUMERIC_NINF)
+ else:
+ buf[2] = endian.htobe16(NUMERIC_PINF)
+ else:
+ raise e.DataError(f"unexpected decimal exponent: {pyexp}")
+ return length
+
+ cdef int exp = pyexp
+ cdef uint16_t ndigits = <uint16_t>len(digits)
+
+ # Find the last nonzero digit
+ cdef int nzdigits = ndigits
+ while nzdigits > 0 and digits[nzdigits - 1] == 0:
+ nzdigits -= 1
+
+ cdef uint16_t dscale
+ if exp <= 0:
+ dscale = -exp
+ else:
+ dscale = 0
+ # align the py digits to the pg digits if there's some py exponent
+ ndigits += exp % DEC_DIGITS
+
+ if nzdigits == 0:
+ length = 4 * sizeof(uint16_t)
+ buf = <uint16_t *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = 0 # ndigits
+ buf[1] = 0 # weight
+ buf[2] = endian.htobe16(NUMERIC_POS) # sign
+ buf[3] = endian.htobe16(dscale)
+ return length
+
+ # Equivalent of 0-padding left to align the py digits to the pg digits
+ # but without changing the digits tuple.
+ cdef int wi = 0
+ cdef int mod = (ndigits - dscale) % DEC_DIGITS
+ if mod < 0:
+ # the difference between C and Py % operator
+ mod += 4
+ if mod:
+ wi = DEC_DIGITS - mod
+ ndigits += wi
+
+ cdef int tmp = nzdigits + wi
+ cdef int pgdigits = tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1)
+ length = (pgdigits + 4) * sizeof(uint16_t)
+ buf = <uint16_t*>CDumper.ensure_size(rv, offset, length)
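+    # Header: number of pg digits, weight (base-10000 digits before the
+    # decimal point, minus one), sign, dscale (decimal digits after the point).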
+ buf[0] = endian.htobe16(pgdigits)
+ buf[1] = endian.htobe16(<int16_t>((ndigits + exp) // DEC_DIGITS - 1))
+ buf[2] = endian.htobe16(NUMERIC_NEG) if sign else endian.htobe16(NUMERIC_POS)
+ buf[3] = endian.htobe16(dscale)
+
+ cdef uint16_t pgdigit = 0
+ cdef int bi = 4
+ for i in range(nzdigits):
+ pgdigit += pydigit_weights[wi] * <int>(digits[i])
+ wi += 1
+ if wi >= DEC_DIGITS:
+ buf[bi] = endian.htobe16(pgdigit)
+ pgdigit = wi = 0
+ bi += 1
+
+ if pgdigit:
+ buf[bi] = endian.htobe16(pgdigit)
+
+ return length
+
+
+cdef Py_ssize_t dump_int_to_text(obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef long long val
+ cdef int overflow
+ cdef char *buf
+ cdef char *src
+ cdef Py_ssize_t length
+
+ # Ensure an int or a subclass. The 'is' type check is fast.
+ # Passing a float must give an error, but passing an Enum should work.
+ if type(obj) is not int and not isinstance(obj, int):
+ raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
+
+ val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+ if not overflow:
+ buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1)
+ length = pg_lltoa(val, buf)
+ else:
+ b = bytes(str(obj), "utf-8")
+ PyBytes_AsStringAndSize(b, &src, &length)
+ buf = CDumper.ensure_size(rv, offset, length)
+ memcpy(buf, src, length)
+
+ return length
+
+
+cdef Py_ssize_t dump_int_to_numeric_binary(obj, bytearray rv, Py_ssize_t offset) except -1:
+ # Calculate the number of PG digits required to store the number
+ cdef uint16_t ndigits
+ ndigits = <uint16_t>((<int>obj.bit_length()) * BIT_PER_PGDIGIT) + 1
+
+ cdef uint16_t sign = NUMERIC_POS
+ if obj < 0:
+ sign = NUMERIC_NEG
+ obj = -obj
+
+ cdef Py_ssize_t length = sizeof(uint16_t) * (ndigits + 4)
+ cdef uint16_t *buf
+ buf = <uint16_t *><void *>CDumper.ensure_size(rv, offset, length)
+ buf[0] = endian.htobe16(ndigits)
+ buf[1] = endian.htobe16(ndigits - 1) # weight
+ buf[2] = endian.htobe16(sign)
+ buf[3] = 0 # dscale
+
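+    # Write the base-10000 digits starting from the least significant one,
+    # then zero-fill any leading positions left over by the ndigits estimate.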
+ cdef int i = 4 + ndigits - 1
+ cdef uint16_t rem
+ while obj:
+ rem = obj % 10000
+ obj //= 10000
+ buf[i] = endian.htobe16(rem)
+ i -= 1
+ while i > 3:
+ buf[i] = 0
+ i -= 1
+
+ return length
diff --git a/psycopg_c/psycopg_c/types/numutils.c b/psycopg_c/psycopg_c/types/numutils.c
new file mode 100644
index 0000000..4be7108
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/numutils.c
@@ -0,0 +1,243 @@
+/*
+ * Utilities to deal with numbers.
+ *
+ * Copyright (C) 2020 The Psycopg Team
+ * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ */
+
+#include <stdint.h>
+#include <string.h>
+
+#include "pg_config.h"
+
+
+/*
+ * 64-bit integers
+ */
+#ifdef HAVE_LONG_INT_64
+/* Plain "long int" fits, use it */
+
+# ifndef HAVE_INT64
+typedef long int int64;
+# endif
+# ifndef HAVE_UINT64
+typedef unsigned long int uint64;
+# endif
+# define INT64CONST(x) (x##L)
+# define UINT64CONST(x) (x##UL)
+#elif defined(HAVE_LONG_LONG_INT_64)
+/* We have working support for "long long int", use that */
+
+# ifndef HAVE_INT64
+typedef long long int int64;
+# endif
+# ifndef HAVE_UINT64
+typedef unsigned long long int uint64;
+# endif
+# define INT64CONST(x) (x##LL)
+# define UINT64CONST(x) (x##ULL)
+#else
+/* neither HAVE_LONG_INT_64 nor HAVE_LONG_LONG_INT_64 */
+# error must have a working 64-bit integer datatype
+#endif
+
+
+#ifndef HAVE__BUILTIN_CLZ
+static const uint8_t pg_leftmost_one_pos[256] = {
+ 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
+};
+#endif
+
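+/*
+ * Pairs of decimal digits: entries 2*i and 2*i+1 hold the two characters of
+ * the number i (00..99), so two digits can be emitted with a single memcpy.
+ */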
+static const char DIGIT_TABLE[200] = {
+ '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0',
+ '7', '0', '8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4',
+ '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2',
+ '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
+ '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3',
+ '7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4',
+ '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5',
+ '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
+ '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6',
+ '7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4',
+ '7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8',
+ '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
+ '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9',
+ '7', '9', '8', '9', '9'
+};
+
+
+/*
+ * pg_leftmost_one_pos64
+ *		Return the 0-based position of the most significant set bit of a
+ *		64-bit word (the word must be nonzero).
+ */
+static inline int
+pg_leftmost_one_pos64(uint64_t word)
+{
+#ifdef HAVE__BUILTIN_CLZ
+#if defined(HAVE_LONG_INT_64)
+ return 63 - __builtin_clzl(word);
+#elif defined(HAVE_LONG_LONG_INT_64)
+ return 63 - __builtin_clzll(word);
+#else
+#error must have a working 64-bit integer datatype
+#endif
+#else /* !HAVE__BUILTIN_CLZ */
+ int shift = 64 - 8;
+
+ while ((word >> shift) == 0)
+ shift -= 8;
+
+ return shift + pg_leftmost_one_pos[(word >> shift) & 255];
+#endif /* HAVE__BUILTIN_CLZ */
+}
+
+
+static inline int
+decimalLength64(const uint64_t v)
+{
+ int t;
+ static const uint64_t PowersOfTen[] = {
+ UINT64CONST(1), UINT64CONST(10),
+ UINT64CONST(100), UINT64CONST(1000),
+ UINT64CONST(10000), UINT64CONST(100000),
+ UINT64CONST(1000000), UINT64CONST(10000000),
+ UINT64CONST(100000000), UINT64CONST(1000000000),
+ UINT64CONST(10000000000), UINT64CONST(100000000000),
+ UINT64CONST(1000000000000), UINT64CONST(10000000000000),
+ UINT64CONST(100000000000000), UINT64CONST(1000000000000000),
+ UINT64CONST(10000000000000000), UINT64CONST(100000000000000000),
+ UINT64CONST(1000000000000000000), UINT64CONST(10000000000000000000)
+ };
+
+ /*
+ * Compute base-10 logarithm by dividing the base-2 logarithm by a
+ * good-enough approximation of the base-2 logarithm of 10
+ */
+ t = (pg_leftmost_one_pos64(v) + 1) * 1233 / 4096;
+ return t + (v >= PowersOfTen[t]);
+}
+
+
+/*
+ * Get the decimal representation, not NUL-terminated, and return the length of
+ * same. Caller must ensure that a points to at least MAXINT8LEN bytes.
+ */
+int
+pg_ulltoa_n(uint64_t value, char *a)
+{
+ int olength,
+ i = 0;
+ uint32_t value2;
+
+ /* Degenerate case */
+ if (value == 0)
+ {
+ *a = '0';
+ return 1;
+ }
+
+ olength = decimalLength64(value);
+
+ /* Compute the result string. */
+ while (value >= 100000000)
+ {
+ const uint64_t q = value / 100000000;
+ uint32_t value2 = (uint32_t) (value - 100000000 * q);
+
+ const uint32_t c = value2 % 10000;
+ const uint32_t d = value2 / 10000;
+ const uint32_t c0 = (c % 100) << 1;
+ const uint32_t c1 = (c / 100) << 1;
+ const uint32_t d0 = (d % 100) << 1;
+ const uint32_t d1 = (d / 100) << 1;
+
+ char *pos = a + olength - i;
+
+ value = q;
+
+ memcpy(pos - 2, DIGIT_TABLE + c0, 2);
+ memcpy(pos - 4, DIGIT_TABLE + c1, 2);
+ memcpy(pos - 6, DIGIT_TABLE + d0, 2);
+ memcpy(pos - 8, DIGIT_TABLE + d1, 2);
+ i += 8;
+ }
+
+ /* Switch to 32-bit for speed */
+ value2 = (uint32_t) value;
+
+ if (value2 >= 10000)
+ {
+ const uint32_t c = value2 - 10000 * (value2 / 10000);
+ const uint32_t c0 = (c % 100) << 1;
+ const uint32_t c1 = (c / 100) << 1;
+
+ char *pos = a + olength - i;
+
+ value2 /= 10000;
+
+ memcpy(pos - 2, DIGIT_TABLE + c0, 2);
+ memcpy(pos - 4, DIGIT_TABLE + c1, 2);
+ i += 4;
+ }
+ if (value2 >= 100)
+ {
+ const uint32_t c = (value2 % 100) << 1;
+ char *pos = a + olength - i;
+
+ value2 /= 100;
+
+ memcpy(pos - 2, DIGIT_TABLE + c, 2);
+ i += 2;
+ }
+ if (value2 >= 10)
+ {
+ const uint32_t c = value2 << 1;
+ char *pos = a + olength - i;
+
+ memcpy(pos - 2, DIGIT_TABLE + c, 2);
+ }
+ else
+ *a = (char) ('0' + value2);
+
+ return olength;
+}
+
+/*
+ * pg_lltoa: converts a signed 64-bit integer to its string representation and
+ * returns strlen(a).
+ *
+ * Caller must ensure that 'a' points to enough memory to hold the result
+ * (at least MAXINT8LEN + 1 bytes, counting a leading sign and trailing NUL).
+ */
+int
+pg_lltoa(int64_t value, char *a)
+{
+ uint64_t uvalue = value;
+ int len = 0;
+
+ if (value < 0)
+ {
+ uvalue = (uint64_t) 0 - uvalue;
+ a[len++] = '-';
+ }
+
+ len += pg_ulltoa_n(uvalue, a + len);
+ a[len] = '\0';
+ return len;
+}
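
pg_ulltoa_n above sizes its output up front with decimalLength64 (a log2-to-log10 approximation corrected against a powers-of-ten table) and then writes two digits per step out of the 200-byte DIGIT_TABLE, filling the buffer from the right. A rough Python rendering of the same idea, for illustration only (the shipped code is the C above; these names are not part of the diff):

    DIGIT_TABLE = "".join(f"{i:02d}" for i in range(100))

    def decimal_length64(v: int) -> int:
        # log10(v) ~ bit_length(v) * 1233 / 4096, then correct against 10**t
        t = v.bit_length() * 1233 // 4096
        return t + (v >= 10 ** t)

    def ulltoa(value: int) -> str:
        if value == 0:
            return "0"
        out = [""] * decimal_length64(value)
        pos = len(out)
        while value >= 100:
            value, r = divmod(value, 100)
            out[pos - 2:pos] = DIGIT_TABLE[2 * r:2 * r + 2]  # two digits at a time
            pos -= 2
        if value >= 10:
            out[pos - 2:pos] = DIGIT_TABLE[2 * value:2 * value + 2]
        else:
            out[pos - 1] = chr(ord("0") + value)
        return "".join(out)

    assert all(ulltoa(n) == str(n) for n in (0, 7, 42, 99, 100, 2**63, 2**64 - 1))
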
diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx
new file mode 100644
index 0000000..da18b01
--- /dev/null
+++ b/psycopg_c/psycopg_c/types/string.pyx
@@ -0,0 +1,315 @@
+"""
+Cython adapters for textual types.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+cimport cython
+
+from libc.string cimport memcpy, memchr
+from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
+from cpython.unicode cimport (
+ PyUnicode_AsEncodedString,
+ PyUnicode_AsUTF8String,
+ PyUnicode_CheckExact,
+ PyUnicode_Decode,
+ PyUnicode_DecodeUTF8,
+)
+
+from psycopg_c.pq cimport libpq, Escaping, _buffer_as_string_and_size
+
+from psycopg import errors as e
+from psycopg._encodings import pg2pyenc
+
+cdef extern from "Python.h":
+ const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL
+
+
+cdef class _BaseStrDumper(CDumper):
+ cdef int is_utf8
+ cdef char *encoding
+ cdef bytes _bytes_encoding # needed to keep `encoding` alive
+
+ def __cinit__(self, cls, context: Optional[AdaptContext] = None):
+
+ self.is_utf8 = 0
+ self.encoding = "utf-8"
+ cdef const char *pgenc
+
+ if self._pgconn is not None:
+ pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding")
+ if pgenc == NULL or pgenc == b"UTF8":
+ self._bytes_encoding = b"utf-8"
+ self.is_utf8 = 1
+ else:
+ self._bytes_encoding = pg2pyenc(pgenc).encode()
+ if self._bytes_encoding == b"ascii":
+ self.is_utf8 = 1
+ self.encoding = PyBytes_AsString(self._bytes_encoding)
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ # the server will raise DataError subclass if the string contains 0x00
+ cdef Py_ssize_t size;
+ cdef const char *src
+
+ if self.is_utf8:
+ # Probably the fastest path, but doesn't work with subclasses
+ if PyUnicode_CheckExact(obj):
+ src = PyUnicode_AsUTF8AndSize(obj, &size)
+ else:
+ b = PyUnicode_AsUTF8String(obj)
+ PyBytes_AsStringAndSize(b, <char **>&src, &size)
+ else:
+ b = PyUnicode_AsEncodedString(obj, self.encoding, NULL)
+ PyBytes_AsStringAndSize(b, <char **>&src, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+cdef class _StrBinaryDumper(_BaseStrDumper):
+
+ format = PQ_BINARY
+
+
+@cython.final
+cdef class StrBinaryDumper(_StrBinaryDumper):
+
+ oid = oids.TEXT_OID
+
+
+@cython.final
+cdef class StrBinaryDumperVarchar(_StrBinaryDumper):
+
+ oid = oids.VARCHAR_OID
+
+
+@cython.final
+cdef class StrBinaryDumperName(_StrBinaryDumper):
+
+ oid = oids.NAME_OID
+
+
+cdef class _StrDumper(_BaseStrDumper):
+
+ format = PQ_TEXT
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef Py_ssize_t size = StrBinaryDumper.cdump(self, obj, rv, offset)
+
+ # Like the binary dump, but check for 0, or the string will be truncated
+ cdef const char *buf = PyByteArray_AS_STRING(rv)
+ if NULL != memchr(buf + offset, 0x00, size):
+ raise e.DataError(
+ "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+ )
+ return size
+
+
+@cython.final
+cdef class StrDumper(_StrDumper):
+
+ oid = oids.TEXT_OID
+
+
+@cython.final
+cdef class StrDumperVarchar(_StrDumper):
+
+ oid = oids.VARCHAR_OID
+
+
+@cython.final
+cdef class StrDumperName(_StrDumper):
+
+ oid = oids.NAME_OID
+
+
+@cython.final
+cdef class StrDumperUnknown(_StrDumper):
+ pass
+
+
+cdef class _TextLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef int is_utf8
+ cdef char *encoding
+ cdef bytes _bytes_encoding # needed to keep `encoding` alive
+
+ def __cinit__(self, oid: int, context: Optional[AdaptContext] = None):
+
+ self.is_utf8 = 0
+ self.encoding = "utf-8"
+ cdef const char *pgenc
+
+ if self._pgconn is not None:
+ pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding")
+ if pgenc == NULL or pgenc == b"UTF8":
+ self._bytes_encoding = b"utf-8"
+ self.is_utf8 = 1
+ else:
+ self._bytes_encoding = pg2pyenc(pgenc).encode()
+
+ if pgenc == b"SQL_ASCII":
+ self.encoding = NULL
+ else:
+ self.encoding = PyBytes_AsString(self._bytes_encoding)
+
+ cdef object cload(self, const char *data, size_t length):
+ if self.is_utf8:
+ return PyUnicode_DecodeUTF8(<char *>data, length, NULL)
+ elif self.encoding:
+ return PyUnicode_Decode(<char *>data, length, self.encoding, NULL)
+ else:
+ return data[:length]
+
+@cython.final
+cdef class TextLoader(_TextLoader):
+
+ format = PQ_TEXT
+
+
+@cython.final
+cdef class TextBinaryLoader(_TextLoader):
+
+ format = PQ_BINARY
+
+
+@cython.final
+cdef class BytesDumper(CDumper):
+
+ format = PQ_TEXT
+ oid = oids.BYTEA_OID
+
+ # 0: not set, 1: just single "'" quote, 3: " E'" quote
+ cdef int _qplen
+
+ def __cinit__(self):
+ self._qplen = 0
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+
+ _buffer_as_string_and_size(obj, &ptr, &length)
+
+ if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+ out = libpq.PQescapeByteaConn(
+ self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_out)
+ else:
+ out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out)
+
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {length} bytes"
+ )
+
+ len_out -= 1 # out includes final 0
+ cdef char *buf = CDumper.ensure_size(rv, offset, len_out)
+ memcpy(buf, out, len_out)
+ libpq.PQfreemem(out)
+ return len_out
+
+ def quote(self, obj):
+ cdef size_t len_out
+ cdef unsigned char *out
+ cdef char *ptr
+ cdef Py_ssize_t length
+ cdef const char *scs
+
+ escaped = self.dump(obj)
+ _buffer_as_string_and_size(escaped, &ptr, &length)
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes' content. If scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self._pgconn is not None:
+ if not self._qplen:
+ scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr,
+ b"standard_conforming_strings")
+ if scs and scs[0] == b'o' and scs[1] == b"n": # == "on"
+ self._qplen = 1
+ else:
+ self._qplen = 3
+
+ PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ if self._qplen == 1:
+ ptr_out[0] = b"'"
+ else:
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + self._qplen, ptr, length)
+ ptr_out[length + self._qplen] = b"'"
+ return rv
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. With PQescapeBytea, like its
+ # string counterpart, it is not predictable whether it will escape
+ # backslashes.
+ PyByteArray_Resize(rv, length + 4) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ esc = Escaping()
+ if esc.escape_bytea(b"\x00") == b"\\000":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
+
+ return rv
+
+
+@cython.final
+cdef class BytesBinaryDumper(CDumper):
+
+ format = PQ_BINARY
+ oid = oids.BYTEA_OID
+
+ cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+ cdef char *src
+ cdef Py_ssize_t size;
+ _buffer_as_string_and_size(obj, &src, &size)
+
+ cdef char *buf = CDumper.ensure_size(rv, offset, size)
+ memcpy(buf, src, size)
+ return size
+
+
+@cython.final
+cdef class ByteaLoader(CLoader):
+
+ format = PQ_TEXT
+
+ cdef object cload(self, const char *data, size_t length):
+ cdef size_t len_out
+ cdef unsigned char *out = libpq.PQunescapeBytea(
+ <const unsigned char *>data, &len_out)
+ if out is NULL:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = out[:len_out]
+ libpq.PQfreemem(out)
+ return rv
+
+
+@cython.final
+cdef class ByteaBinaryLoader(CLoader):
+
+ format = PQ_BINARY
+
+ cdef object cload(self, const char *data, size_t length):
+ return data[:length]
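
These dumpers and loaders are what back plain str and bytes parameters and results in psycopg; the text str dumper additionally rejects NUL bytes, which PostgreSQL text fields cannot store. A minimal usage sketch, assuming a reachable server (the DSN below is a placeholder):

    import psycopg

    with psycopg.connect("dbname=test") as conn:
        row = conn.execute(
            "SELECT %s::text, %s::bytea", ["héllo", b"\x00\xff"]
        ).fetchone()
        print(row)  # ('héllo', b'\x00\xff')

        # The str dumper raises DataError client-side for NUL bytes.
        try:
            conn.execute("SELECT %s::text", ["bad\x00value"])
        except psycopg.DataError as ex:
            print(ex)
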
diff --git a/psycopg_c/psycopg_c/version.py b/psycopg_c/psycopg_c/version.py
new file mode 100644
index 0000000..5c989c2
--- /dev/null
+++ b/psycopg_c/psycopg_c/version.py
@@ -0,0 +1,11 @@
+"""
+psycopg-c distribution version file.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+__version__ = "3.1.7"
+
+# also change psycopg/psycopg/version.py accordingly.
diff --git a/psycopg_c/pyproject.toml b/psycopg_c/pyproject.toml
new file mode 100644
index 0000000..f0d7a3f
--- /dev/null
+++ b/psycopg_c/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0.0a11"]
+build-backend = "setuptools.build_meta"
diff --git a/psycopg_c/setup.cfg b/psycopg_c/setup.cfg
new file mode 100644
index 0000000..6c5c93c
--- /dev/null
+++ b/psycopg_c/setup.cfg
@@ -0,0 +1,57 @@
+[metadata]
+name = psycopg-c
+description = PostgreSQL database adapter for Python -- C optimisation distribution
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg-c/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+setup_requires = Cython >= 3.0.0a11
+packages = find:
+zip_safe = False
+
+[options.package_data]
+# NOTE: do not include .pyx files: they shouldn't be in the sdist
+# package, so that build is only performed from the .c files (which are
+# distributed instead).
+psycopg_c =
+ py.typed
+ *.pyi
+ *.pxd
+ _psycopg/*.pxd
+ pq/*.pxd
+
+# In the psycopg-binary distribution don't include cython-related files.
+psycopg_binary =
+ py.typed
+ *.pyi
diff --git a/psycopg_c/setup.py b/psycopg_c/setup.py
new file mode 100644
index 0000000..c6da3a1
--- /dev/null
+++ b/psycopg_c/setup.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - optimisation package
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+import sys
+import subprocess as sp
+
+from setuptools import setup, Extension
+from distutils.command.build_ext import build_ext
+from distutils import log
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+with open("psycopg_c/version.py") as f:
+ data = f.read()
+ m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
+ if m is None:
+ raise Exception(f"cannot find version in {f.name}")
+ version = m.group(1)
+
+
+def get_config(what: str) -> str:
+ pg_config = "pg_config"
+ try:
+ out = sp.run([pg_config, f"--{what}"], stdout=sp.PIPE, check=True)
+ except Exception as e:
+ log.error(f"couldn't run {pg_config!r} --{what}: %s", e)
+ raise
+ else:
+ return out.stdout.strip().decode()
+
+
+class psycopg_build_ext(build_ext):
+ def finalize_options(self) -> None:
+ self._setup_ext_build()
+ super().finalize_options()
+
+ def _setup_ext_build(self) -> None:
+ cythonize = None
+
+ # In the sdist there are no .pyx files, only .c, so we don't need Cython.
+ # Otherwise Cython is a requirement and it is used to compile pyx to c.
+ if os.path.exists("psycopg_c/_psycopg.pyx"):
+ from Cython.Build import cythonize
+
+ # Add include and lib dir for the libpq.
+ includedir = get_config("includedir")
+ libdir = get_config("libdir")
+ for ext in self.distribution.ext_modules:
+ ext.include_dirs.append(includedir)
+ ext.library_dirs.append(libdir)
+
+ if sys.platform == "win32":
+ # For __imp_htons and others
+ ext.libraries.append("ws2_32")
+
+ if cythonize is not None:
+ for ext in self.distribution.ext_modules:
+ for i in range(len(ext.sources)):
+ base, fext = os.path.splitext(ext.sources[i])
+ if fext == ".c" and os.path.exists(base + ".pyx"):
+ ext.sources[i] = base + ".pyx"
+
+ self.distribution.ext_modules = cythonize(
+ self.distribution.ext_modules,
+ language_level=3,
+ compiler_directives={
+ "always_allow_keywords": False,
+ },
+ annotate=False, # enable to get an html view of the C module
+ )
+ else:
+ self.distribution.ext_modules = [pgext, pqext]
+
+
+# MSVC requires an explicit "libpq"
+libpq = "pq" if sys.platform != "win32" else "libpq"
+
+# Some details missing, to be finished by psycopg_build_ext.finalize_options
+pgext = Extension(
+ "psycopg_c._psycopg",
+ [
+ "psycopg_c/_psycopg.c",
+ "psycopg_c/types/numutils.c",
+ ],
+ libraries=[libpq],
+ include_dirs=[],
+)
+
+pqext = Extension(
+ "psycopg_c.pq",
+ ["psycopg_c/pq.c"],
+ libraries=[libpq],
+ include_dirs=[],
+)
+
+setup(
+ version=version,
+ ext_modules=[pgext, pqext],
+ cmdclass={"build_ext": psycopg_build_ext},
+)
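
The build hooks above rely on pg_config to locate the libpq headers and library, and only invoke Cython when the .pyx sources are present (i.e. when building from a source checkout rather than the sdist). A quick standalone check of what pg_config will report, mirroring get_config() above (assumes pg_config is on PATH; not part of the build itself):

    import subprocess as sp

    for opt in ("--includedir", "--libdir"):
        out = sp.run(["pg_config", opt], stdout=sp.PIPE, check=True)
        print(opt, out.stdout.strip().decode())
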
diff --git a/psycopg_pool/.flake8 b/psycopg_pool/.flake8
new file mode 100644
index 0000000..2ae629c
--- /dev/null
+++ b/psycopg_pool/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+ignore = W503, E203
diff --git a/psycopg_pool/LICENSE.txt b/psycopg_pool/LICENSE.txt
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/psycopg_pool/LICENSE.txt
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/psycopg_pool/README.rst b/psycopg_pool/README.rst
new file mode 100644
index 0000000..6e6b32c
--- /dev/null
+++ b/psycopg_pool/README.rst
@@ -0,0 +1,24 @@
+Psycopg 3: PostgreSQL database adapter for Python - Connection Pool
+===================================================================
+
+This distribution contains the optional connection pool package
+`psycopg_pool`__.
+
+.. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html
+
+This package is kept separate from the main ``psycopg`` package because it is
+likely that it will follow a different release cycle.
+
+You can also install this package using::
+
+ pip install "psycopg[pool]"
+
+Please read `the project readme`__ and `the installation documentation`__ for
+more details.
+
+.. __: https://github.com/psycopg/psycopg#readme
+.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html
+ #installing-the-connection-pool
+
+
+Copyright (C) 2020 The Psycopg Team
diff --git a/psycopg_pool/psycopg_pool/__init__.py b/psycopg_pool/psycopg_pool/__init__.py
new file mode 100644
index 0000000..e4d975f
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/__init__.py
@@ -0,0 +1,22 @@
+"""
+psycopg connection pool package
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from .pool import ConnectionPool
+from .pool_async import AsyncConnectionPool
+from .null_pool import NullConnectionPool
+from .null_pool_async import AsyncNullConnectionPool
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from .version import __version__ as __version__ # noqa: F401
+
+__all__ = [
+ "AsyncConnectionPool",
+ "AsyncNullConnectionPool",
+ "ConnectionPool",
+ "NullConnectionPool",
+ "PoolClosed",
+ "PoolTimeout",
+ "TooManyRequests",
+]
diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py
new file mode 100644
index 0000000..9fb2b9b
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/_compat.py
@@ -0,0 +1,51 @@
+"""
+compatibility functions for different Python versions
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import sys
+import asyncio
+from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar
+from typing_extensions import TypeAlias
+
+import psycopg.errors as e
+
+T = TypeVar("T")
+FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
+
+if sys.version_info >= (3, 8):
+ create_task = asyncio.create_task
+ Task = asyncio.Task
+
+else:
+
+ def create_task(
+ coro: FutureT[T], name: Optional[str] = None
+ ) -> "asyncio.Future[T]":
+ return asyncio.create_task(coro)
+
+ Task = asyncio.Future
+
+if sys.version_info >= (3, 9):
+ from collections import Counter, deque as Deque
+else:
+ from typing import Counter, Deque
+
+__all__ = [
+ "Counter",
+ "Deque",
+ "Task",
+ "create_task",
+]
+
+# Workaround for psycopg < 3.0.8.
+# Timeout on NullPool connection might not work correctly.
+try:
+ ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout
+except AttributeError:
+
+ class DummyConnectionTimeout(e.OperationalError):
+ pass
+
+ ConnectionTimeout = DummyConnectionTimeout
diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py
new file mode 100644
index 0000000..298ea68
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/base.py
@@ -0,0 +1,230 @@
+"""
+psycopg connection pool base class and functionalities.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from time import monotonic
+from random import random
+from typing import Any, Callable, Dict, Generic, Optional, Tuple
+
+from psycopg import errors as e
+from psycopg.abc import ConnectionType
+
+from .errors import PoolClosed
+from ._compat import Counter, Deque
+
+
+class BasePool(Generic[ConnectionType]):
+
+ # Used to generate pool names
+ _num_pool = 0
+
+ # Stats keys
+ _POOL_MIN = "pool_min"
+ _POOL_MAX = "pool_max"
+ _POOL_SIZE = "pool_size"
+ _POOL_AVAILABLE = "pool_available"
+ _REQUESTS_WAITING = "requests_waiting"
+ _REQUESTS_NUM = "requests_num"
+ _REQUESTS_QUEUED = "requests_queued"
+ _REQUESTS_WAIT_MS = "requests_wait_ms"
+ _REQUESTS_ERRORS = "requests_errors"
+ _USAGE_MS = "usage_ms"
+ _RETURNS_BAD = "returns_bad"
+ _CONNECTIONS_NUM = "connections_num"
+ _CONNECTIONS_MS = "connections_ms"
+ _CONNECTIONS_ERRORS = "connections_errors"
+ _CONNECTIONS_LOST = "connections_lost"
+
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ kwargs: Optional[Dict[str, Any]] = None,
+ min_size: int = 4,
+ max_size: Optional[int] = None,
+ open: bool = True,
+ name: Optional[str] = None,
+ timeout: float = 30.0,
+ max_waiting: int = 0,
+ max_lifetime: float = 60 * 60.0,
+ max_idle: float = 10 * 60.0,
+ reconnect_timeout: float = 5 * 60.0,
+ reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None,
+ num_workers: int = 3,
+ ):
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ if not name:
+ num = BasePool._num_pool = BasePool._num_pool + 1
+ name = f"pool-{num}"
+
+ if num_workers < 1:
+ raise ValueError("num_workers must be at least 1")
+
+ self.conninfo = conninfo
+ self.kwargs: Dict[str, Any] = kwargs or {}
+ self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
+ self._reconnect_failed = reconnect_failed or (lambda pool: None)
+ self.name = name
+ self._min_size = min_size
+ self._max_size = max_size
+ self.timeout = timeout
+ self.max_waiting = max_waiting
+ self.reconnect_timeout = reconnect_timeout
+ self.max_lifetime = max_lifetime
+ self.max_idle = max_idle
+ self.num_workers = num_workers
+
+ self._nconns = min_size # currently in the pool, out, being prepared
+ self._pool = Deque[ConnectionType]()
+ self._stats = Counter[str]()
+
+ # Min number of connections in the pool in a max_idle unit of time.
+ # It is reset periodically by the ShrinkPool scheduled task.
+ # It is used to shrink back the pool if maxcon > min_size and extra
+ # connections have been acquired, if we notice that in the last
+ # max_idle interval they weren't all used.
+ self._nconns_min = min_size
+
+ # Flag to allow the pool to grow only one connection at a time. In case
+ # of a spike, if threads are allowed to grow in parallel and connection
+ # time is slow, there won't be any thread available to return the
+ # connections to the pool.
+ self._growing = False
+
+ self._opened = False
+ self._closed = True
+
+ def __repr__(self) -> str:
+ return (
+ f"<{self.__class__.__module__}.{self.__class__.__name__}"
+ f" {self.name!r} at 0x{id(self):x}>"
+ )
+
+ @property
+ def min_size(self) -> int:
+ return self._min_size
+
+ @property
+ def max_size(self) -> int:
+ return self._max_size
+
+ @property
+ def closed(self) -> bool:
+ """`!True` if the pool is closed."""
+ return self._closed
+
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
+ if max_size is None:
+ max_size = min_size
+
+ if min_size < 0:
+ raise ValueError("min_size cannot be negative")
+ if max_size < min_size:
+ raise ValueError("max_size must be greater or equal than min_size")
+ if min_size == max_size == 0:
+ raise ValueError("if min_size is 0 max_size must be greater or than 0")
+
+ return min_size, max_size
+
+ def _check_open(self) -> None:
+ if self._closed and self._opened:
+ raise e.OperationalError(
+ "pool has already been opened/closed and cannot be reused"
+ )
+
+ def _check_open_getconn(self) -> None:
+ if self._closed:
+ if self._opened:
+ raise PoolClosed(f"the pool {self.name!r} is already closed")
+ else:
+ raise PoolClosed(f"the pool {self.name!r} is not open yet")
+
+ def _check_pool_putconn(self, conn: ConnectionType) -> None:
+ pool = getattr(conn, "_pool", None)
+ if pool is self:
+ return
+
+ if pool:
+ msg = f"it comes from pool {pool.name!r}"
+ else:
+ msg = "it doesn't come from any pool"
+ raise ValueError(
+ f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+ )
+
+ def get_stats(self) -> Dict[str, int]:
+ """
+ Return current stats about the pool usage.
+ """
+ rv = dict(self._stats)
+ rv.update(self._get_measures())
+ return rv
+
+ def pop_stats(self) -> Dict[str, int]:
+ """
+ Return current stats about the pool usage.
+
+ After the call, all the counters are reset to zero.
+ """
+ stats, self._stats = self._stats, Counter()
+ rv = dict(stats)
+ rv.update(self._get_measures())
+ return rv
+
+ def _get_measures(self) -> Dict[str, int]:
+ """
+ Return immediate measures of the pool (not counters).
+ """
+ return {
+ self._POOL_MIN: self._min_size,
+ self._POOL_MAX: self._max_size,
+ self._POOL_SIZE: self._nconns,
+ self._POOL_AVAILABLE: len(self._pool),
+ }
+
+ @classmethod
+ def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float:
+ """
+ Add a random value to *value* between *min_pc* and *max_pc* percent.
+ """
+ return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
+
+ def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+ """Set an expiry date on a connection.
+
+ Add some randomness to avoid mass reconnection.
+ """
+ conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0)
+
+
+class ConnectionAttempt:
+ """Keep the state of a connection attempt."""
+
+ INITIAL_DELAY = 1.0
+ DELAY_JITTER = 0.1
+ DELAY_BACKOFF = 2.0
+
+ def __init__(self, *, reconnect_timeout: float):
+ self.reconnect_timeout = reconnect_timeout
+ self.delay = 0.0
+ self.give_up_at = 0.0
+
+ def update_delay(self, now: float) -> None:
+ """Calculate how long to wait for a new connection attempt"""
+ if self.delay == 0.0:
+ self.give_up_at = now + self.reconnect_timeout
+ self.delay = BasePool._jitter(
+ self.INITIAL_DELAY, -self.DELAY_JITTER, self.DELAY_JITTER
+ )
+ else:
+ self.delay *= self.DELAY_BACKOFF
+
+ if self.delay + now > self.give_up_at:
+ self.delay = max(0.0, self.give_up_at - now)
+
+ def time_to_give_up(self, now: float) -> bool:
+ """Return True if we are tired of trying to connect. Meh."""
+ return self.give_up_at > 0.0 and now >= self.give_up_at
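
ConnectionAttempt implements the reconnection policy: the first delay is INITIAL_DELAY with a ±DELAY_JITTER spread (via BasePool._jitter), each further failure doubles it, and the last wait is clamped so the whole sequence stays within reconnect_timeout. A small standalone sketch of the resulting schedule, re-implemented here only to show the shape (not part of the pool API):

    import random

    def backoff_schedule(reconnect_timeout=300.0, initial=1.0, jitter=0.1, backoff=2.0):
        """Yield successive waits, roughly mirroring ConnectionAttempt.update_delay."""
        now = 0.0
        # the first delay gets +/- 10% jitter, like BasePool._jitter(initial, -0.1, 0.1)
        delay = initial * (1.0 + (2 * jitter) * random.random() - jitter)
        while now < reconnect_timeout:
            delay = min(delay, reconnect_timeout - now)  # never sleep past the deadline
            yield delay
            now += delay
            delay *= backoff

    print([round(d, 1) for d in backoff_schedule(reconnect_timeout=60.0)])
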
diff --git a/psycopg_pool/psycopg_pool/errors.py b/psycopg_pool/psycopg_pool/errors.py
new file mode 100644
index 0000000..9e672ad
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/errors.py
@@ -0,0 +1,25 @@
+"""
+Connection pool errors.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from psycopg import errors as e
+
+
+class PoolClosed(e.OperationalError):
+ """Attempt to get a connection from a closed pool."""
+
+ __module__ = "psycopg_pool"
+
+
+class PoolTimeout(e.OperationalError):
+ """The pool couldn't provide a connection in acceptable time."""
+
+ __module__ = "psycopg_pool"
+
+
+class TooManyRequests(e.OperationalError):
+ """Too many requests in the queue waiting for a connection from the pool."""
+
+ __module__ = "psycopg_pool"
diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py
new file mode 100644
index 0000000..c0a77c2
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/null_pool.py
@@ -0,0 +1,159 @@
+"""
+Psycopg null connection pools
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import logging
+import threading
+from typing import Any, Optional, Tuple
+
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .pool import ConnectionPool, AddConnection
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class _BaseNullConnectionPool:
+ def __init__(
+ self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any
+ ):
+ super().__init__( # type: ignore[call-arg]
+ conninfo, *args, min_size=min_size, **kwargs
+ )
+
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
+ if max_size is None:
+ max_size = min_size
+
+ if min_size != 0:
+ raise ValueError("null pools must have min_size = 0")
+ if max_size < min_size:
+ raise ValueError("max_size must be greater or equal than min_size")
+
+ return min_size, max_size
+
+ def _start_initial_tasks(self) -> None:
+ # Null pools don't have background tasks to fill connections
+ # or to grow/shrink.
+ return
+
+ def _maybe_grow_pool(self) -> None:
+ # null pools don't grow
+ pass
+
+
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+ def wait(self, timeout: float = 30.0) -> None:
+ """
+ Create a connection as a test.
+
+ Calling this function will verify that connectivity with the
+ database works as expected. However, the connection will not be stored
+ in the pool.
+
+ Close the pool, and raise `PoolTimeout`, if not ready within *timeout*
+ sec.
+ """
+ self._check_open_getconn()
+
+ with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = threading.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ if not self._pool_full_event.wait(timeout):
+ self.close() # stop all the threads
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
+
+ with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[Connection[Any]]:
+ conn: Optional[Connection[Any]] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ with self._lock:
+ if not self._closed and self._waiting:
+ return False
+
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ conn.close()
+ self._nconns -= 1
+ return True
+
+ def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ """Change the size of the pool during runtime.
+
+ Only *max_size* can be changed; *min_size* must remain 0.
+ """
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ def check(self) -> None:
+ """No-op, as the pool doesn't have connections in its state."""
+ pass
+
+ def _add_to_pool(self, conn: Connection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid creating a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shouldn't decrease the
+ # count of the number of connections used.
+ self._nconns -= 1
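
As the code above shows, a null pool keeps no idle connections: getconn opens a fresh connection (bounded by max_size), and putconn closes it unless another client is already waiting for one. A minimal usage sketch (the DSN is a placeholder):

    from psycopg_pool import NullConnectionPool

    # min_size is forced to 0; max_size only caps concurrent connections.
    with NullConnectionPool("dbname=test", max_size=4) as pool:
        with pool.connection() as conn:
            print(conn.execute("SELECT 1").fetchone())
    # leaving the outer block closes the pool; no idle connections were kept
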
diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py
new file mode 100644
index 0000000..ae9d207
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/null_pool_async.py
@@ -0,0 +1,122 @@
+"""
+psycopg asynchronous null connection pool
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import asyncio
+import logging
+from typing import Any, Optional
+
+from psycopg import AsyncConnection
+from psycopg.pq import TransactionStatus
+
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+from .null_pool import _BaseNullConnectionPool
+from .pool_async import AsyncConnectionPool, AddConnection
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+ async def wait(self, timeout: float = 30.0) -> None:
+ self._check_open_getconn()
+
+ async with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = asyncio.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the tasks
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ ) from None
+
+ async with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ async def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[AsyncConnection[Any]]:
+ conn: Optional[AsyncConnection[Any]] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = await self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+ # Close the connection if no client is waiting for it, or if the pool
+ # is closed. For extra refcare remove the pool reference from it.
+ # Maintain the stats.
+ async with self._lock:
+ if not self._closed and self._waiting:
+ return False
+
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ await conn.close()
+ self._nconns -= 1
+ return True
+
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ async with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ async def check(self) -> None:
+ pass
+
+ async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid creating a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ await conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shouldn't decrease the
+ # count of the number of connections used.
+ self._nconns -= 1
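
The asynchronous variant behaves the same way on top of AsyncConnectionPool. A minimal sketch (placeholder DSN; the pool is opened explicitly rather than from the constructor):

    import asyncio
    from psycopg_pool import AsyncNullConnectionPool

    async def main() -> None:
        pool = AsyncNullConnectionPool("dbname=test", max_size=4, open=False)
        await pool.open()
        try:
            async with pool.connection() as conn:
                cur = await conn.execute("SELECT 1")
                print(await cur.fetchone())
        finally:
            await pool.close()

    asyncio.run(main())
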
diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py
new file mode 100644
index 0000000..609d95d
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/pool.py
@@ -0,0 +1,839 @@
+"""
+psycopg synchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+import threading
+from abc import ABC, abstractmethod
+from time import monotonic
+from queue import Queue, Empty
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List
+from typing import Optional, Sequence, Type
+from weakref import ref
+from contextlib import contextmanager
+
+from psycopg import errors as e
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .base import ConnectionAttempt, BasePool
+from .sched import Scheduler
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from ._compat import Deque
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class ConnectionPool(BasePool[Connection[Any]]):
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ open: bool = True,
+ connection_class: Type[Connection[Any]] = Connection,
+ configure: Optional[Callable[[Connection[Any]], None]] = None,
+ reset: Optional[Callable[[Connection[Any]], None]] = None,
+ **kwargs: Any,
+ ):
+ self.connection_class = connection_class
+ self._configure = configure
+ self._reset = reset
+
+ self._lock = threading.RLock()
+ self._waiting = Deque["WaitingClient"]()
+
+ # to notify that the pool is full
+ self._pool_full_event: Optional[threading.Event] = None
+
+ self._sched = Scheduler()
+ self._sched_runner: Optional[threading.Thread] = None
+ self._tasks: "Queue[MaintenanceTask]" = Queue()
+ self._workers: List[threading.Thread] = []
+
+ super().__init__(conninfo, **kwargs)
+
+ if open:
+ self.open()
+
+ def __del__(self) -> None:
+ # If the '_closed' property is not set we probably failed in __init__.
+ # Don't try anything complicated as probably it won't work.
+ if getattr(self, "_closed", True):
+ return
+
+ self._stop_workers()
+
+ def wait(self, timeout: float = 30.0) -> None:
+ """
+ Wait for the pool to be full (with `min_size` connections) after creation.
+
+ Close the pool, and raise `PoolTimeout`, if not ready within *timeout*
+ sec.
+
+ Calling this method is not mandatory: you can try and use the pool
+ immediately after its creation. The first client will be served as soon
+ as a connection is ready. You can use this method if you prefer your
+ program to terminate in case the environment is not configured
+ properly, rather than trying to stay up the hardest it can.
+ """
+ self._check_open_getconn()
+
+ with self._lock:
+ assert not self._pool_full_event
+ if len(self._pool) >= self._min_size:
+ return
+ self._pool_full_event = threading.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ if not self._pool_full_event.wait(timeout):
+ self.close() # stop all the threads
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
+
+ with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ @contextmanager
+ def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
+ """Context manager to obtain a connection from the pool.
+
+ Return the connection immediately if available, otherwise wait up to
+ *timeout* or `self.timeout` seconds and throw `PoolTimeout` if a
+ connection is not available in time.
+
+ Upon context exit, return the connection to the pool. Apply the normal
+ :ref:`connection context behaviour <with-connection>` (commit/rollback
+ the transaction in case of success/error). If the connection is no longer
+ in a working state, replace it with a new one.
+ """
+ conn = self.getconn(timeout=timeout)
+ t0 = monotonic()
+ try:
+ with conn:
+ yield conn
+ finally:
+ t1 = monotonic()
+ self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
+ self.putconn(conn)
+
+ def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+ """Obtain a connection from the pool.
+
+ You should preferably use `connection()`. Use this function only if
+ it is not possible to use the connection as context manager.
+
+ After using this function you *must* call a corresponding `putconn()`:
+ failing to do so will deplete the pool. A depleted pool is a sad pool:
+ you don't want a depleted pool.
+ """
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ with self._lock:
+ self._check_open_getconn()
+ conn = self._get_ready_connection(timeout)
+ if not conn:
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = WaitingClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If there is space for the pool to grow, let's do it
+ self._maybe_grow_pool()
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if not conn:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ # Note that this property shouldn't be set while the connection is in
+ # the pool, to avoid creating a reference loop.
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[Connection[Any]]:
+ """Return a connection, if the client deserves one."""
+ conn: Optional[Connection[Any]] = None
+ if self._pool:
+ # Take a connection ready out of the pool
+ conn = self._pool.popleft()
+ if len(self._pool) < self._nconns_min:
+ self._nconns_min = len(self._pool)
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_grow_pool(self) -> None:
+ # Allow only one thread at a time to grow the pool (or returning
+ # connections might be starved).
+ if self._nconns >= self._max_size or self._growing:
+ return
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self._growing = True
+ self.run_task(AddConnection(self, growing=True))
+
+ def putconn(self, conn: Connection[Any]) -> None:
+ """Return a connection to the loving hands of its pool.
+
+ Use this function only paired with a `getconn()`. You don't need to use
+ it if you use the much more comfortable `connection()` context manager.
+ """
+ # Quick check to discard the wrong connection
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+
+ if self._maybe_close_connection(conn):
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate thread
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ self._return_connection(conn)
+
+ def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ """Close a returned connection if necessary.
+
+ Return `!True` if the connection was closed.
+ """
+ # If the pool is closed just close the connection instead of returning
+ # it to the pool. For extra refcare remove the pool reference from it.
+ if not self._closed:
+ return False
+
+ conn._pool = None
+ conn.close()
+ return True
+
+ def open(self, wait: bool = False, timeout: float = 30.0) -> None:
+ """Open the pool by starting connecting and and accepting clients.
+
+ If *wait* is `!False`, return immediately and let the background worker
+ fill the pool if `min_size` > 0. Otherwise wait up to *timeout* seconds
+ for the requested number of connections to be ready (see `wait()` for
+ details).
+
+ It is safe to call `!open()` again on a pool already open (because the
+ method was already called, or because the pool context was entered, or
+ because the pool was initialized with *open* = `!True`) but you cannot
+ currently re-open a closed pool.
+ """
+ with self._lock:
+ self._open()
+
+ if wait:
+ self.wait(timeout=timeout)
+
+ def _open(self) -> None:
+ if not self._closed:
+ return
+
+ self._check_open()
+
+ self._closed = False
+ self._opened = True
+
+ self._start_workers()
+ self._start_initial_tasks()
+
+ def _start_workers(self) -> None:
+ self._sched_runner = threading.Thread(
+ target=self._sched.run,
+ name=f"{self.name}-scheduler",
+ daemon=True,
+ )
+ assert not self._workers
+ for i in range(self.num_workers):
+ t = threading.Thread(
+ target=self.worker,
+ args=(self._tasks,),
+ name=f"{self.name}-worker-{i}",
+ daemon=True,
+ )
+ self._workers.append(t)
+
+ # The object state is complete. Start the worker threads
+ self._sched_runner.start()
+ for t in self._workers:
+ t.start()
+
+ def _start_initial_tasks(self) -> None:
+ # populate the pool with initial min_size connections in background
+ for i in range(self._nconns):
+ self.run_task(AddConnection(self))
+
+ # Schedule a task to shrink the pool if connections over min_size have
+ # remained unused.
+ self.schedule_task(ShrinkPool(self), self.max_idle)
+
+ def close(self, timeout: float = 5.0) -> None:
+ """Close the pool and make it unavailable to new clients.
+
+ All the waiting and future clients will fail to acquire a connection
+ with a `PoolClosed` exception. Currently used connections will not be
+ closed until returned to the pool.
+
+ Wait *timeout* seconds for threads to terminate their job, if positive.
+ If the timeout expires the pool is closed anyway, although it may raise
+ some warnings on exit.
+ """
+ if self._closed:
+ return
+
+ with self._lock:
+ self._closed = True
+ logger.debug("pool %r closed", self.name)
+
+ # Take waiting client and pool connections out of the state
+ waiting = list(self._waiting)
+ self._waiting.clear()
+ connections = list(self._pool)
+ self._pool.clear()
+
+ # Now that the flag _closed is set, getconn will fail immediately,
+ # putconn will just close the returned connection.
+ self._stop_workers(waiting, connections, timeout)
+
+ def _stop_workers(
+ self,
+ waiting_clients: Sequence["WaitingClient"] = (),
+ connections: Sequence[Connection[Any]] = (),
+ timeout: float = 0.0,
+ ) -> None:
+
+ # Stop the scheduler
+ self._sched.enter(0, None)
+
+ # Stop the worker threads
+ workers, self._workers = self._workers[:], []
+ for i in range(len(workers)):
+ self.run_task(StopWorker(self))
+
+ # Signal to eventual clients in the queue that business is closed.
+ for pos in waiting_clients:
+ pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+ # Close the connections still in the pool
+ for conn in connections:
+ conn.close()
+
+ # Wait for the worker threads to terminate
+ assert self._sched_runner is not None
+ sched_runner, self._sched_runner = self._sched_runner, None
+ if timeout > 0:
+ for t in [sched_runner] + workers:
+ if not t.is_alive():
+ continue
+ t.join(timeout)
+ if t.is_alive():
+ logger.warning(
+ "couldn't stop thread %s in pool %r within %s seconds",
+ t,
+ self.name,
+ timeout,
+ )
+
+ def __enter__(self) -> "ConnectionPool":
+ self.open()
+ return self
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.close()
+
+ def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ """Change the size of the pool during runtime."""
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ ngrow = max(0, min_size - self._min_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+ self._nconns += ngrow
+
+ for i in range(ngrow):
+ self.run_task(AddConnection(self))
+
+ def check(self) -> None:
+ """Verify the state of the connections currently in the pool.
+
+ Test each connection: if it works return it to the pool, otherwise
+ dispose of it and create a new one.
+ """
+ with self._lock:
+ conns = list(self._pool)
+ self._pool.clear()
+
+ # Give a chance to the pool to grow if it has no connection.
+ # In case there are enough connections, or the pool is already
+ # growing, this is a no-op.
+ self._maybe_grow_pool()
+
+ while conns:
+ conn = conns.pop()
+ try:
+ conn.execute("SELECT 1")
+ if conn.pgconn.transaction_status == TransactionStatus.INTRANS:
+ conn.rollback()
+ except Exception:
+ self._stats[self._CONNECTIONS_LOST] += 1
+ logger.warning("discarding broken connection: %s", conn)
+ self.run_task(AddConnection(self))
+ else:
+ self._add_to_pool(conn)
+
+ def reconnect_failed(self) -> None:
+ """
+ Called when reconnection failed for longer than `reconnect_timeout`.
+ """
+ self._reconnect_failed(self)
+
+ def run_task(self, task: "MaintenanceTask") -> None:
+ """Run a maintenance task in a worker thread."""
+ self._tasks.put_nowait(task)
+
+ def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
+ """Run a maintenance task in a worker thread in the future."""
+ self._sched.enter(delay, task.tick)
+
+ _WORKER_TIMEOUT = 60.0
+
+ @classmethod
+ def worker(cls, q: "Queue[MaintenanceTask]") -> None:
+ """Runner to execute pending maintenance task.
+
+ The function is designed to run as a separate thread.
+
+ Block on the queue *q*, run the tasks received. Finish running if a
+ StopWorker is received.
+ """
+ # Don't make all the workers time out at the same moment
+ timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1)
+ while True:
+ # Use a timeout to make the wait interruptible
+ try:
+ task = q.get(timeout=timeout)
+ except Empty:
+ continue
+
+ if isinstance(task, StopWorker):
+ logger.debug(
+ "terminating working thread %s",
+ threading.current_thread().name,
+ )
+ return
+
+ # Run the task. Make sure we don't die in the attempt.
+ try:
+ task.run()
+ except Exception as ex:
+ logger.warning(
+ "task run %s failed: %s: %s",
+ task,
+ ex.__class__.__name__,
+ ex,
+ )
+
+ def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
+ """Return a new connection configured for the pool."""
+ self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
+ t0 = monotonic()
+ try:
+ conn: Connection[Any]
+ conn = self.connection_class.connect(self.conninfo, **kwargs)
+ except Exception:
+ self._stats[self._CONNECTIONS_ERRORS] += 1
+ raise
+ else:
+ t1 = monotonic()
+ self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0))
+
+ conn._pool = self
+
+ if self._configure:
+ self._configure(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by configure function"
+ f" {self._configure}: discarded"
+ )
+
+ # Set an expiry date, with some randomness to avoid mass reconnection
+ self._set_connection_expiry_date(conn)
+ return conn
+
+ def _add_connection(
+ self, attempt: Optional[ConnectionAttempt], growing: bool = False
+ ) -> None:
+ """Try to connect and add the connection to the pool.
+
+        If the connection attempt fails, reschedule new attempts a few times,
+        then give up, decrease the number of pool connections and call
+        `self.reconnect_failed()`.
+
+ """
+ now = monotonic()
+ if not attempt:
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
+
+ try:
+ conn = self._connect()
+ except Exception as ex:
+ logger.warning(f"error connecting in {self.name!r}: {ex}")
+ if attempt.time_to_give_up(now):
+ logger.warning(
+ "reconnection attempt in pool %r failed after %s sec",
+ self.name,
+ self.reconnect_timeout,
+ )
+ with self._lock:
+ self._nconns -= 1
+ # If we have given up with a growing attempt, allow a new one.
+ if growing and self._growing:
+ self._growing = False
+ self.reconnect_failed()
+ else:
+ attempt.update_delay(now)
+ self.schedule_task(
+ AddConnection(self, attempt, growing=growing),
+ attempt.delay,
+ )
+ return
+
+ logger.info("adding new connection to the pool")
+ self._add_to_pool(conn)
+ if growing:
+ with self._lock:
+ # Keep on growing if the pool is not full yet, or if there are
+ # clients waiting and the pool can extend.
+ if self._nconns < self._min_size or (
+ self._nconns < self._max_size and self._waiting
+ ):
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self.run_task(AddConnection(self, growing=True))
+ else:
+ self._growing = False
+
+ def _return_connection(self, conn: Connection[Any]) -> None:
+ """
+ Return a connection to the pool after usage.
+ """
+ self._reset_connection(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+            # Connection no longer in a working state: create a new one.
+ self.run_task(AddConnection(self))
+ logger.warning("discarding closed connection: %s", conn)
+ return
+
+ # Check if the connection is past its best before date
+ if conn._expire_at <= monotonic():
+ self.run_task(AddConnection(self))
+ logger.info("discarding expired connection")
+ conn.close()
+ return
+
+ self._add_to_pool(conn)
+
+ def _add_to_pool(self, conn: Connection[Any]) -> None:
+ """
+ Add a connection to the pool.
+
+ The connection can be a fresh one or one already used in the pool.
+
+        If a client is already waiting for a connection, pass it on; otherwise
+        put it back into the pool.
+ """
+ # Remove the pool reference from the connection before returning it
+        # to the pool state, to avoid creating a reference loop.
+        # Also disable the warning for an open connection in conn.__del__.
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: put it back into the pool
+ self._pool.append(conn)
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is full.
+ if self._pool_full_event and len(self._pool) >= self._min_size:
+ self._pool_full_event.set()
+
+ def _reset_connection(self, conn: Connection[Any]) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ pass
+
+ elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ conn.rollback()
+ except Exception as ex:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ ex.__class__.__name__,
+ ex,
+ conn,
+ )
+ conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ conn.close()
+
+ if not conn.closed and self._reset:
+ try:
+ self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ conn.close()
+
+ def _shrink_pool(self) -> None:
+ to_close: Optional[Connection[Any]] = None
+
+ with self._lock:
+ # Reset the min number of connections used
+ nconns_min = self._nconns_min
+ self._nconns_min = len(self._pool)
+
+ # If the pool can shrink and connections were unused, drop one
+ if self._nconns > self._min_size and nconns_min > 0:
+ to_close = self._pool.popleft()
+ self._nconns -= 1
+ self._nconns_min -= 1
+
+ if to_close:
+ logger.info(
+ "shrinking pool %r to %s because %s unused connections"
+ " in the last %s sec",
+ self.name,
+ self._nconns,
+ nconns_min,
+ self.max_idle,
+ )
+ to_close.close()
+
+ def _get_measures(self) -> Dict[str, int]:
+ rv = super()._get_measures()
+ rv[self._REQUESTS_WAITING] = len(self._waiting)
+ return rv
+
+
+class WaitingClient:
+ """A position in a queue for a client waiting for a connection."""
+
+ __slots__ = ("conn", "error", "_cond")
+
+ def __init__(self) -> None:
+ self.conn: Optional[Connection[Any]] = None
+ self.error: Optional[Exception] = None
+
+        # The WaitingClient behaves similarly to an Event, but we need to
+        # reliably notify the flagger that the waiter has "accepted" the
+        # message and hasn't timed out yet; otherwise the pool may give a
+        # connection to a client whose getconn() has already timed out, and
+        # the connection would be lost.
+ self._cond = threading.Condition()
+
+ def wait(self, timeout: float) -> Connection[Any]:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ with self._cond:
+ if not (self.conn or self.error):
+ if not self._cond.wait(timeout):
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
+ raise self.error
+
+ def set(self, conn: Connection[Any]) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out).
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out).
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.error = error
+ self._cond.notify_all()
+ return True
+
+
+class MaintenanceTask(ABC):
+ """A task to run asynchronously to maintain the pool state."""
+
+ def __init__(self, pool: "ConnectionPool"):
+ self.pool = ref(pool)
+
+ def __repr__(self) -> str:
+ pool = self.pool()
+ name = repr(pool.name) if pool else "<pool is gone>"
+ return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+ def run(self) -> None:
+ """Run the task.
+
+ This usually happens in a worker thread. Call the concrete _run()
+ implementation, if the pool is still alive.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+ logger.debug("task run discarded: %s", self)
+ return
+
+ logger.debug("task running in %s: %s", threading.current_thread().name, self)
+ self._run(pool)
+
+ def tick(self) -> None:
+        """Run the scheduled task.
+
+ This function is called by the scheduler thread. Use a worker to
+ run the task for real in order to free the scheduler immediately.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+ logger.debug("task tick discarded: %s", self)
+ return
+
+ pool.run_task(self)
+
+ @abstractmethod
+ def _run(self, pool: "ConnectionPool") -> None:
+ ...
+
+
+class StopWorker(MaintenanceTask):
+ """Signal the maintenance thread to terminate."""
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pass
+
+
+class AddConnection(MaintenanceTask):
+ def __init__(
+ self,
+ pool: "ConnectionPool",
+ attempt: Optional["ConnectionAttempt"] = None,
+ growing: bool = False,
+ ):
+ super().__init__(pool)
+ self.attempt = attempt
+ self.growing = growing
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._add_connection(self.attempt, growing=self.growing)
+
+
+class ReturnConnection(MaintenanceTask):
+ """Clean up and return a connection to the pool."""
+
+ def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"):
+ super().__init__(pool)
+ self.conn = conn
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+ """If the pool can shrink, remove one connection.
+
+ Re-schedule periodically and also reset the minimum number of connections
+ in the pool.
+ """
+
+ def _run(self, pool: "ConnectionPool") -> None:
+ # Reschedule the task now so that in case of any error we don't lose
+ # the periodic run.
+ pool.schedule_task(self, pool.max_idle)
+ pool._shrink_pool()
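
For context, a minimal usage sketch of the synchronous pool defined above (the DSN, sizes and timeouts are illustrative, not part of the patch):

    from psycopg_pool import ConnectionPool

    # Opening the pool starts the worker threads and the background
    # creation of min_size connections; __exit__ closes it.
    with ConnectionPool("dbname=test", min_size=4, max_size=10) as pool:
        pool.wait(timeout=30.0)  # block until min_size connections are ready

        # connection() checks a connection out and returns it on exit; the
        # connection's own context manager commits or rolls back.
        with pool.connection(timeout=5.0) as conn:
            conn.execute("SELECT 1")

        pool.resize(min_size=2, max_size=8)  # change the limits at runtime
        pool.check()  # discard broken idle connections, replacing them
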
diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py
new file mode 100644
index 0000000..0ea6e9a
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/pool_async.py
@@ -0,0 +1,784 @@
+"""
+psycopg asynchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+from abc import ABC, abstractmethod
+from time import monotonic
+from types import TracebackType
+from typing import Any, AsyncIterator, Awaitable, Callable
+from typing import Dict, List, Optional, Sequence, Type
+from weakref import ref
+from contextlib import asynccontextmanager
+
+from psycopg import errors as e
+from psycopg import AsyncConnection
+from psycopg.pq import TransactionStatus
+
+from .base import ConnectionAttempt, BasePool
+from .sched import AsyncScheduler
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
+from ._compat import Task, create_task, Deque
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
+ def __init__(
+ self,
+ conninfo: str = "",
+ *,
+ open: bool = True,
+ connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
+ configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ **kwargs: Any,
+ ):
+ self.connection_class = connection_class
+ self._configure = configure
+ self._reset = reset
+
+ # asyncio objects, created on open to attach them to the right loop.
+ self._lock: asyncio.Lock
+ self._sched: AsyncScheduler
+ self._tasks: "asyncio.Queue[MaintenanceTask]"
+
+ self._waiting = Deque["AsyncClient"]()
+
+ # to notify that the pool is full
+ self._pool_full_event: Optional[asyncio.Event] = None
+
+ self._sched_runner: Optional[Task[None]] = None
+ self._workers: List[Task[None]] = []
+
+ super().__init__(conninfo, **kwargs)
+
+ if open:
+ self._open()
+
+ async def wait(self, timeout: float = 30.0) -> None:
+ self._check_open_getconn()
+
+ async with self._lock:
+ assert not self._pool_full_event
+ if len(self._pool) >= self._min_size:
+ return
+ self._pool_full_event = asyncio.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the tasks
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ ) from None
+
+ async with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ @asynccontextmanager
+ async def connection(
+ self, timeout: Optional[float] = None
+ ) -> AsyncIterator[AsyncConnection[Any]]:
+ conn = await self.getconn(timeout=timeout)
+ t0 = monotonic()
+ try:
+ async with conn:
+ yield conn
+ finally:
+ t1 = monotonic()
+ self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
+ await self.putconn(conn)
+
+ async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ self._check_open_getconn()
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ async with self._lock:
+ conn = await self._get_ready_connection(timeout)
+ if not conn:
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = AsyncClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If there is space for the pool to grow, let's do it
+ self._maybe_grow_pool()
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if not conn:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = await pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ # Note that this property shouldn't be set while the connection is in
+        # the pool, to avoid creating a reference loop.
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ async def _get_ready_connection(
+ self, timeout: Optional[float]
+ ) -> Optional[AsyncConnection[Any]]:
+ conn: Optional[AsyncConnection[Any]] = None
+ if self._pool:
+ # Take a connection ready out of the pool
+ conn = self._pool.popleft()
+ if len(self._pool) < self._nconns_min:
+ self._nconns_min = len(self._pool)
+ elif self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has already"
+ f" {len(self._waiting)} requests waiting"
+ )
+ return conn
+
+ def _maybe_grow_pool(self) -> None:
+        # Allow only one task at a time to grow the pool (or returning
+ # connections might be starved).
+ if self._nconns < self._max_size and not self._growing:
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self._growing = True
+ self.run_task(AddConnection(self, growing=True))
+
+ async def putconn(self, conn: AsyncConnection[Any]) -> None:
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+ if await self._maybe_close_connection(conn):
+ return
+
+        # Use a worker to perform any maintenance work in a separate task
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ await self._return_connection(conn)
+
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+        # If the pool is closed, just close the connection instead of returning
+        # it to the pool. For extra care, also remove the pool reference from it.
+ if not self._closed:
+ return False
+
+ conn._pool = None
+ await conn.close()
+ return True
+
+ async def open(self, wait: bool = False, timeout: float = 30.0) -> None:
+ # Make sure the lock is created after there is an event loop
+ try:
+ self._lock
+ except AttributeError:
+ self._lock = asyncio.Lock()
+
+ async with self._lock:
+ self._open()
+
+ if wait:
+ await self.wait(timeout=timeout)
+
+ def _open(self) -> None:
+ if not self._closed:
+ return
+
+        # Raise a RuntimeError if the pool is opened outside a running loop.
+ asyncio.get_running_loop()
+
+ self._check_open()
+
+ # Create these objects now to attach them to the right loop.
+ # See #219
+ self._tasks = asyncio.Queue()
+ self._sched = AsyncScheduler()
+        # This has most likely, but not necessarily, been created in `open()`.
+ try:
+ self._lock
+ except AttributeError:
+ self._lock = asyncio.Lock()
+
+ self._closed = False
+ self._opened = True
+
+ self._start_workers()
+ self._start_initial_tasks()
+
+ def _start_workers(self) -> None:
+ self._sched_runner = create_task(
+ self._sched.run(), name=f"{self.name}-scheduler"
+ )
+ for i in range(self.num_workers):
+ t = create_task(
+ self.worker(self._tasks),
+ name=f"{self.name}-worker-{i}",
+ )
+ self._workers.append(t)
+
+ def _start_initial_tasks(self) -> None:
+ # populate the pool with initial min_size connections in background
+ for i in range(self._nconns):
+ self.run_task(AddConnection(self))
+
+ # Schedule a task to shrink the pool if connections over min_size have
+ # remained unused.
+ self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
+
+ async def close(self, timeout: float = 5.0) -> None:
+ if self._closed:
+ return
+
+ async with self._lock:
+ self._closed = True
+ logger.debug("pool %r closed", self.name)
+
+            # Take waiting clients and pool connections out of the state
+ waiting = list(self._waiting)
+ self._waiting.clear()
+ connections = list(self._pool)
+ self._pool.clear()
+
+        # Now that the flag _closed is set, getconn will fail immediately and
+        # putconn will just close the returned connection.
+ await self._stop_workers(waiting, connections, timeout)
+
+ async def _stop_workers(
+ self,
+ waiting_clients: Sequence["AsyncClient"] = (),
+ connections: Sequence[AsyncConnection[Any]] = (),
+ timeout: float = 0.0,
+ ) -> None:
+ # Stop the scheduler
+ await self._sched.enter(0, None)
+
+ # Stop the worker tasks
+ workers, self._workers = self._workers[:], []
+ for w in workers:
+ self.run_task(StopWorker(self))
+
+        # Signal to any clients in the queue that business is closed.
+ for pos in waiting_clients:
+ await pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+ # Close the connections still in the pool
+ for conn in connections:
+ await conn.close()
+
+ # Wait for the worker tasks to terminate
+ assert self._sched_runner is not None
+ sched_runner, self._sched_runner = self._sched_runner, None
+ wait = asyncio.gather(sched_runner, *workers)
+ try:
+ if timeout > 0:
+ await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+ else:
+ await wait
+ except asyncio.TimeoutError:
+ logger.warning(
+ "couldn't stop pool %r tasks within %s seconds",
+ self.name,
+ timeout,
+ )
+
+ async def __aenter__(self) -> "AsyncConnectionPool":
+ await self.open()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ await self.close()
+
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ ngrow = max(0, min_size - self._min_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ async with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+ self._nconns += ngrow
+
+ for i in range(ngrow):
+ self.run_task(AddConnection(self))
+
+ async def check(self) -> None:
+ async with self._lock:
+ conns = list(self._pool)
+ self._pool.clear()
+
+            # Give the pool a chance to grow if it has no connections.
+            # If there are enough connections, or the pool is already
+            # growing, this is a no-op.
+ self._maybe_grow_pool()
+
+ while conns:
+ conn = conns.pop()
+ try:
+ await conn.execute("SELECT 1")
+ if conn.pgconn.transaction_status == TransactionStatus.INTRANS:
+ await conn.rollback()
+ except Exception:
+ self._stats[self._CONNECTIONS_LOST] += 1
+ logger.warning("discarding broken connection: %s", conn)
+ self.run_task(AddConnection(self))
+ else:
+ await self._add_to_pool(conn)
+
+ def reconnect_failed(self) -> None:
+ """
+ Called when reconnection failed for longer than `reconnect_timeout`.
+ """
+ self._reconnect_failed(self)
+
+ def run_task(self, task: "MaintenanceTask") -> None:
+ """Run a maintenance task in a worker."""
+ self._tasks.put_nowait(task)
+
+ async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
+ """Run a maintenance task in a worker in the future."""
+ await self._sched.enter(delay, task.tick)
+
+ @classmethod
+ async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None:
+        """Runner to execute pending maintenance tasks.
+
+ The function is designed to run as a task.
+
+        Block on the queue *q* and run the tasks received. Stop running when
+        a StopWorker is received.
+ """
+ while True:
+ task = await q.get()
+
+ if isinstance(task, StopWorker):
+                logger.debug("terminating worker task")
+ return
+
+            # Run the task. Make sure we don't die in the attempt.
+ try:
+ await task.run()
+ except Exception as ex:
+ logger.warning(
+ "task run %s failed: %s: %s",
+ task,
+ ex.__class__.__name__,
+ ex,
+ )
+
+ async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
+ t0 = monotonic()
+ try:
+ conn: AsyncConnection[Any]
+ conn = await self.connection_class.connect(self.conninfo, **kwargs)
+ except Exception:
+ self._stats[self._CONNECTIONS_ERRORS] += 1
+ raise
+ else:
+ t1 = monotonic()
+ self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0))
+
+ conn._pool = self
+
+ if self._configure:
+ await self._configure(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by configure function"
+ f" {self._configure}: discarded"
+ )
+
+ # Set an expiry date, with some randomness to avoid mass reconnection
+ self._set_connection_expiry_date(conn)
+ return conn
+
+ async def _add_connection(
+ self, attempt: Optional[ConnectionAttempt], growing: bool = False
+ ) -> None:
+ """Try to connect and add the connection to the pool.
+
+        If the connection attempt fails, reschedule new attempts a few times,
+        then give up, decrease the number of pool connections and call
+        `self.reconnect_failed()`.
+
+ """
+ now = monotonic()
+ if not attempt:
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
+
+ try:
+ conn = await self._connect()
+ except Exception as ex:
+ logger.warning(f"error connecting in {self.name!r}: {ex}")
+ if attempt.time_to_give_up(now):
+ logger.warning(
+ "reconnection attempt in pool %r failed after %s sec",
+ self.name,
+ self.reconnect_timeout,
+ )
+ async with self._lock:
+ self._nconns -= 1
+ # If we have given up with a growing attempt, allow a new one.
+ if growing and self._growing:
+ self._growing = False
+ self.reconnect_failed()
+ else:
+ attempt.update_delay(now)
+ await self.schedule_task(
+ AddConnection(self, attempt, growing=growing),
+ attempt.delay,
+ )
+ return
+
+ logger.info("adding new connection to the pool")
+ await self._add_to_pool(conn)
+ if growing:
+ async with self._lock:
+ # Keep on growing if the pool is not full yet, or if there are
+ # clients waiting and the pool can extend.
+ if self._nconns < self._min_size or (
+ self._nconns < self._max_size and self._waiting
+ ):
+ self._nconns += 1
+ logger.info("growing pool %r to %s", self.name, self._nconns)
+ self.run_task(AddConnection(self, growing=True))
+ else:
+ self._growing = False
+
+ async def _return_connection(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Return a connection to the pool after usage.
+ """
+ await self._reset_connection(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+            # Connection no longer in a working state: create a new one.
+ self.run_task(AddConnection(self))
+ logger.warning("discarding closed connection: %s", conn)
+ return
+
+ # Check if the connection is past its best before date
+ if conn._expire_at <= monotonic():
+ self.run_task(AddConnection(self))
+ logger.info("discarding expired connection")
+ await conn.close()
+ return
+
+ await self._add_to_pool(conn)
+
+ async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Add a connection to the pool.
+
+ The connection can be a fresh one or one already used in the pool.
+
+        If a client is already waiting for a connection, pass it on; otherwise
+        put it back into the pool.
+ """
+ # Remove the pool reference from the connection before returning it
+        # to the pool state, to avoid creating a reference loop.
+        # Also disable the warning for an open connection in conn.__del__.
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: put it back into the pool
+ self._pool.append(conn)
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is full.
+ if self._pool_full_event and len(self._pool) >= self._min_size:
+ self._pool_full_event.set()
+
+ async def _reset_connection(self, conn: AsyncConnection[Any]) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ pass
+
+ elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ await conn.rollback()
+ except Exception as ex:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ ex.__class__.__name__,
+ ex,
+ conn,
+ )
+ await conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ await conn.close()
+
+ if not conn.closed and self._reset:
+ try:
+ await self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ sname = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {sname} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ await conn.close()
+
+ async def _shrink_pool(self) -> None:
+ to_close: Optional[AsyncConnection[Any]] = None
+
+ async with self._lock:
+ # Reset the min number of connections used
+ nconns_min = self._nconns_min
+ self._nconns_min = len(self._pool)
+
+ # If the pool can shrink and connections were unused, drop one
+ if self._nconns > self._min_size and nconns_min > 0:
+ to_close = self._pool.popleft()
+ self._nconns -= 1
+ self._nconns_min -= 1
+
+ if to_close:
+ logger.info(
+ "shrinking pool %r to %s because %s unused connections"
+ " in the last %s sec",
+ self.name,
+ self._nconns,
+ nconns_min,
+ self.max_idle,
+ )
+ await to_close.close()
+
+ def _get_measures(self) -> Dict[str, int]:
+ rv = super()._get_measures()
+ rv[self._REQUESTS_WAITING] = len(self._waiting)
+ return rv
+
+
+class AsyncClient:
+ """A position in a queue for a client waiting for a connection."""
+
+ __slots__ = ("conn", "error", "_cond")
+
+ def __init__(self) -> None:
+ self.conn: Optional[AsyncConnection[Any]] = None
+ self.error: Optional[Exception] = None
+
+        # The AsyncClient behaves similarly to an Event, but we need to
+        # reliably notify the flagger that the waiter has "accepted" the
+        # message and hasn't timed out yet; otherwise the pool may give a
+        # connection to a client whose getconn() has already timed out, and
+        # the connection would be lost.
+ self._cond = asyncio.Condition()
+
+ async def wait(self, timeout: float) -> AsyncConnection[Any]:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ async with self._cond:
+ if not (self.conn or self.error):
+ try:
+ await asyncio.wait_for(self._cond.wait(), timeout)
+ except asyncio.TimeoutError:
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
+ raise self.error
+
+ async def set(self, conn: AsyncConnection[Any]) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ async def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.error = error
+ self._cond.notify_all()
+ return True
+
+
+class MaintenanceTask(ABC):
+ """A task to run asynchronously to maintain the pool state."""
+
+ def __init__(self, pool: "AsyncConnectionPool"):
+ self.pool = ref(pool)
+
+ def __repr__(self) -> str:
+ pool = self.pool()
+ name = repr(pool.name) if pool else "<pool is gone>"
+ return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+ async def run(self) -> None:
+ """Run the task.
+
+ This usually happens in a worker. Call the concrete _run()
+ implementation, if the pool is still alive.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+ logger.debug("task run discarded: %s", self)
+ return
+
+ await self._run(pool)
+
+ async def tick(self) -> None:
+        """Run the scheduled task.
+
+ This function is called by the scheduler task. Use a worker to
+ run the task for real in order to free the scheduler immediately.
+ """
+ pool = self.pool()
+ if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+ logger.debug("task tick discarded: %s", self)
+ return
+
+ pool.run_task(self)
+
+ @abstractmethod
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ ...
+
+
+class StopWorker(MaintenanceTask):
+ """Signal the maintenance worker to terminate."""
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ pass
+
+
+class AddConnection(MaintenanceTask):
+ def __init__(
+ self,
+ pool: "AsyncConnectionPool",
+ attempt: Optional["ConnectionAttempt"] = None,
+ growing: bool = False,
+ ):
+ super().__init__(pool)
+ self.attempt = attempt
+ self.growing = growing
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool._add_connection(self.attempt, growing=self.growing)
+
+
+class ReturnConnection(MaintenanceTask):
+ """Clean up and return a connection to the pool."""
+
+ def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
+ super().__init__(pool)
+ self.conn = conn
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+ """If the pool can shrink, remove one connection.
+
+ Re-schedule periodically and also reset the minimum number of connections
+ in the pool.
+ """
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ # Reschedule the task now so that in case of any error we don't lose
+ # the periodic run.
+ await pool.schedule_task(self, pool.max_idle)
+ await pool._shrink_pool()
+
+
+class Schedule(MaintenanceTask):
+ """Schedule a task in the pool scheduler.
+
+    This task is a trampoline that allows using a sync call (pool.run_task)
+ to execute an async one (pool.schedule_task).
+ """
+
+ def __init__(
+ self,
+ pool: "AsyncConnectionPool",
+ task: MaintenanceTask,
+ delay: float,
+ ):
+ super().__init__(pool)
+ self.task = task
+ self.delay = delay
+
+ async def _run(self, pool: "AsyncConnectionPool") -> None:
+ await pool.schedule_task(self.task, self.delay)
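
A matching sketch for the asynchronous pool (again with illustrative connection parameters); open() is awaited inside the running loop so that the queue, lock and scheduler attach to it:

    import asyncio
    from psycopg_pool import AsyncConnectionPool

    async def main() -> None:
        # open=False defers opening until we are inside the event loop.
        pool = AsyncConnectionPool("dbname=test", min_size=2, open=False)
        await pool.open(wait=True, timeout=30.0)
        try:
            async with pool.connection(timeout=5.0) as conn:
                await conn.execute("SELECT 1")
        finally:
            await pool.close()

    asyncio.run(main())
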
diff --git a/psycopg_pool/psycopg_pool/py.typed b/psycopg_pool/psycopg_pool/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/py.typed
diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py
new file mode 100644
index 0000000..ca26007
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/sched.py
@@ -0,0 +1,177 @@
+"""
+A minimal scheduler to schedule tasks to run in the future.
+
+Inspired by the standard library `sched.scheduler`, but designed for
+multi-thread usage from the ground up, not as an afterthought. Tasks can be
+scheduled in front of the one currently running and `Scheduler.run()` can be
+left running without any task scheduled.
+
+Tasks are called "Task" here, not "Event", because we actually make use of
+`threading.Event` and the two would be confusing.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+import threading
+from time import monotonic
+from heapq import heappush, heappop
+from typing import Any, Callable, List, Optional, NamedTuple
+
+logger = logging.getLogger(__name__)
+
+
+class Task(NamedTuple):
+ time: float
+ action: Optional[Callable[[], Any]]
+
+ def __eq__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time == other.time
+
+ def __lt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time < other.time
+
+ def __le__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time <= other.time
+
+ def __gt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time > other.time
+
+ def __ge__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time >= other.time
+
+
+class Scheduler:
+ def __init__(self) -> None:
+        """Initialize a new scheduler instance."""
+ self._queue: List[Task] = []
+ self._lock = threading.RLock()
+ self._event = threading.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return self.enterabs(time, action)
+
+ def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ def run(self) -> None:
+        """Execute the scheduled tasks."""
+ q = self._queue
+ while True:
+ with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ if not task.action:
+ break
+ try:
+ task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+                # Block for the expected timeout or until a new task is scheduled
+ self._event.wait(timeout=delay)
+
+
+class AsyncScheduler:
+ def __init__(self) -> None:
+        """Initialize a new scheduler instance."""
+ self._queue: List[Task] = []
+ self._lock = asyncio.Lock()
+ self._event = asyncio.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return await self.enterabs(time, action)
+
+ async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ async with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ async def run(self) -> None:
+        """Execute the scheduled tasks."""
+ q = self._queue
+ while True:
+ async with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ if not task.action:
+ break
+ try:
+ await task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+                # Block for the expected timeout or until a new task is scheduled
+ try:
+ await asyncio.wait_for(self._event.wait(), delay)
+ except asyncio.TimeoutError:
+ pass
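
The scheduler is an internal helper, but a small sketch of how it is meant to be driven may help (the thread and callbacks below are illustrative):

    import threading
    from psycopg_pool.sched import Scheduler

    sched = Scheduler()
    runner = threading.Thread(target=sched.run, daemon=True)
    runner.start()

    sched.enter(0.5, lambda: print("runs after ~0.5 sec"))
    sched.enter(0.1, lambda: print("runs first: scheduled in front"))

    # A task with a None action makes run() return, stopping the runner.
    sched.enter(1.0, None)
    runner.join()
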
diff --git a/psycopg_pool/psycopg_pool/version.py b/psycopg_pool/psycopg_pool/version.py
new file mode 100644
index 0000000..fc99bbd
--- /dev/null
+++ b/psycopg_pool/psycopg_pool/version.py
@@ -0,0 +1,13 @@
+"""
+psycopg pool version file.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+# Use a versioning scheme as defined in
+# https://www.python.org/dev/peps/pep-0440/
+
+# STOP AND READ! if you change:
+__version__ = "3.1.5"
+# also change:
+# - `docs/news_pool.rst` to declare this version current or unreleased
diff --git a/psycopg_pool/setup.cfg b/psycopg_pool/setup.cfg
new file mode 100644
index 0000000..1a3274e
--- /dev/null
+++ b/psycopg_pool/setup.cfg
@@ -0,0 +1,45 @@
+[metadata]
+name = psycopg-pool
+description = Connection Pool for Psycopg
+url = https://psycopg.org/psycopg3/
+author = Daniele Varrazzo
+author_email = daniele.varrazzo@gmail.com
+license = GNU Lesser General Public License v3 (LGPLv3)
+
+project_urls =
+ Homepage = https://psycopg.org/
+ Code = https://github.com/psycopg/psycopg
+ Issue Tracker = https://github.com/psycopg/psycopg/issues
+ Download = https://pypi.org/project/psycopg-pool/
+
+classifiers =
+ Development Status :: 5 - Production/Stable
+ Intended Audience :: Developers
+ License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)
+ Operating System :: MacOS :: MacOS X
+ Operating System :: Microsoft :: Windows
+ Operating System :: POSIX
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
+ Topic :: Database
+ Topic :: Database :: Front-Ends
+ Topic :: Software Development
+ Topic :: Software Development :: Libraries :: Python Modules
+
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+license_files = LICENSE.txt
+
+[options]
+python_requires = >= 3.7
+packages = find:
+zip_safe = False
+install_requires =
+ typing-extensions >= 3.10
+
+[options.package_data]
+psycopg_pool = py.typed
diff --git a/psycopg_pool/setup.py b/psycopg_pool/setup.py
new file mode 100644
index 0000000..771847d
--- /dev/null
+++ b/psycopg_pool/setup.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+"""
+PostgreSQL database adapter for Python - Connection Pool
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import os
+import re
+from setuptools import setup
+
+# Move to the directory of setup.py: executing this file from another location
+# (e.g. from the project root) will fail
+here = os.path.abspath(os.path.dirname(__file__))
+if os.path.abspath(os.getcwd()) != here:
+ os.chdir(here)
+
+with open("psycopg_pool/version.py") as f:
+ data = f.read()
+ m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data)
+ if not m:
+ raise Exception(f"cannot find version in {f.name}")
+ version = m.group(1)
+
+
+setup(version=version)
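
As a quick check of the version-extraction regex used above (the sample string is hypothetical):

    import re

    sample = '__version__ = "3.1.5"'
    m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", sample)
    assert m and m.group(1) == "3.1.5"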