diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /psycopg | |
parent | Initial commit. (diff) | |
download | psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
123 files changed, 29329 insertions, 0 deletions
diff --git a/psycopg/.flake8 b/psycopg/.flake8 new file mode 100644 index 0000000..67fb024 --- /dev/null +++ b/psycopg/.flake8 @@ -0,0 +1,6 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 +per-file-ignores = + # Autogenerated section + psycopg/errors.py: E125, E128, E302 diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg/README.rst b/psycopg/README.rst new file mode 100644 index 0000000..45eeac3 --- /dev/null +++ b/psycopg/README.rst @@ -0,0 +1,31 @@ +Psycopg 3: PostgreSQL database adapter for Python +================================================= + +Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python. + +This distribution contains the pure Python package ``psycopg``. + + +Installation +------------ + +In short, run the following:: + + pip install --upgrade pip # to upgrade pip + pip install "psycopg[binary,pool]" # to install package and dependencies + +If something goes wrong, and for more information about installation, please +check out the `Installation documentation`__. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html# + + +Hacking +------- + +For development information check out `the project readme`__. + +.. __: https://github.com/psycopg/psycopg#readme + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py new file mode 100644 index 0000000..baadf30 --- /dev/null +++ b/psycopg/psycopg/__init__.py @@ -0,0 +1,110 @@ +""" +psycopg -- PostgreSQL database adapter for Python +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from . import pq # noqa: F401 import early to stabilize side effects +from . import types +from . import postgres +from ._tpc import Xid +from .copy import Copy, AsyncCopy +from ._enums import IsolationLevel +from .cursor import Cursor +from .errors import Warning, Error, InterfaceError, DatabaseError +from .errors import DataError, OperationalError, IntegrityError +from .errors import InternalError, ProgrammingError, NotSupportedError +from ._column import Column +from .conninfo import ConnectionInfo +from ._pipeline import Pipeline, AsyncPipeline +from .connection import BaseConnection, Connection, Notify +from .transaction import Rollback, Transaction, AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor, ServerCursor +from .client_cursor import AsyncClientCursor, ClientCursor +from .connection_async import AsyncConnection + +from . import dbapi20 +from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING +from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks +from .dbapi20 import Timestamp, TimestampFromTicks + +from .version import __version__ as __version__ # noqa: F401 + +# Set the logger to a quiet default, can be enabled if needed +logger = logging.getLogger("psycopg") +if logger.level == logging.NOTSET: + logger.setLevel(logging.WARNING) + +# DBAPI compliance +connect = Connection.connect +apilevel = "2.0" +threadsafety = 2 +paramstyle = "pyformat" + +# register default adapters for PostgreSQL +adapters = postgres.adapters # exposed by the package +postgres.register_default_adapters(adapters) + +# After the default ones, because these can deal with the bytea oid better +dbapi20.register_dbapi20_adapters(adapters) + +# Must come after all the types have been registered +types.array.register_all_arrays(adapters) + +# Note: defining the exported methods helps both Sphynx in documenting that +# this is the canonical place to obtain them and should be used by MyPy too, +# so that function signatures are consistent with the documentation. +__all__ = [ + "AsyncClientCursor", + "AsyncConnection", + "AsyncCopy", + "AsyncCursor", + "AsyncPipeline", + "AsyncServerCursor", + "AsyncTransaction", + "BaseConnection", + "ClientCursor", + "Column", + "Connection", + "ConnectionInfo", + "Copy", + "Cursor", + "IsolationLevel", + "Notify", + "Pipeline", + "Rollback", + "ServerCursor", + "Transaction", + "Xid", + # DBAPI exports + "connect", + "apilevel", + "threadsafety", + "paramstyle", + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + # DBAPI type constructors and singletons + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "DATETIME", + "NUMBER", + "ROWID", + "STRING", +] diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py new file mode 100644 index 0000000..a3a6ef8 --- /dev/null +++ b/psycopg/psycopg/_adapters_map.py @@ -0,0 +1,289 @@ +""" +Mapping from types/oids to Dumpers/Loaders +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import cast, TYPE_CHECKING + +from . import pq +from . import errors as e +from .abc import Dumper, Loader +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg +from ._typeinfo import TypesRegistry + +if TYPE_CHECKING: + from .connection import BaseConnection + +RV = TypeVar("RV") + + +class AdaptersMap: + r""" + Establish how types should be converted between Python and PostgreSQL in + an `~psycopg.abc.AdaptContext`. + + `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to + define how Python types are converted to PostgreSQL, and maps OIDs to + `~psycopg.adapt.Loader` classes to establish how query results are + converted to Python. + + Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how + types are converted in that context, exposed as the + `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows + to customise adaptation in a context without changing separated contexts. + + When a context is created from another context (for instance when a + `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's + `!adapters` are used as template for the child's `!adapters`, so that every + cursor created from the same connection use the connection's types + configuration, but separate connections have independent mappings. + + Once created, `!AdaptersMap` are independent. This means that objects + already created are not affected if a wider scope (e.g. the global one) is + changed. + + The connections adapters are initialised using a global `!AdptersMap` + template, exposed as `psycopg.adapters`: changing such mapping allows to + customise the type mapping for every connections created afterwards. + + The object can start empty or copy from another object of the same class. + Copies are copy-on-write: if the maps are updated make a copy. This way + extending e.g. global map by a connection or a connection map from a cursor + is cheap: a copy is only made on customisation. + """ + + __module__ = "psycopg.adapt" + + types: TypesRegistry + + _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]] + _dumpers_by_oid: List[Dict[int, Type[Dumper]]] + _loaders: List[Dict[int, Type[Loader]]] + + # Record if a dumper or loader has an optimised version. + _optimised: Dict[type, type] = {} + + def __init__( + self, + template: Optional["AdaptersMap"] = None, + types: Optional[TypesRegistry] = None, + ): + if template: + self._dumpers = template._dumpers.copy() + self._own_dumpers = _dumpers_shared.copy() + template._own_dumpers = _dumpers_shared.copy() + + self._dumpers_by_oid = template._dumpers_by_oid[:] + self._own_dumpers_by_oid = [False, False] + template._own_dumpers_by_oid = [False, False] + + self._loaders = template._loaders[:] + self._own_loaders = [False, False] + template._own_loaders = [False, False] + + self.types = TypesRegistry(template.types) + + else: + self._dumpers = {fmt: {} for fmt in PyFormat} + self._own_dumpers = _dumpers_owned.copy() + + self._dumpers_by_oid = [{}, {}] + self._own_dumpers_by_oid = [True, True] + + self._loaders = [{}, {}] + self._own_loaders = [True, True] + + self.types = types or TypesRegistry() + + # implement the AdaptContext protocol too + @property + def adapters(self) -> "AdaptersMap": + return self + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return None + + def register_dumper( + self, cls: Union[type, str, None], dumper: Type[Dumper] + ) -> None: + """ + Configure the context to use `!dumper` to convert objects of type `!cls`. + + If two dumpers with different `~Dumper.format` are registered for the + same type, the last one registered will be chosen when the query + doesn't specify a format (i.e. when the value is used with a ``%s`` + "`~PyFormat.AUTO`" placeholder). + + :param cls: The type to manage. + :param dumper: The dumper to register for `!cls`. + + If `!cls` is specified as string it will be lazy-loaded, so that it + will be possible to register it without importing it before. In this + case it should be the fully qualified name of the object (e.g. + ``"uuid.UUID"``). + + If `!cls` is None, only use the dumper when looking up using + `get_dumper_by_oid()`, which happens when we know the Postgres type to + adapt to, but not the Python type that will be adapted (e.g. in COPY + after using `~psycopg.Copy.set_types()`). + + """ + if not (cls is None or isinstance(cls, (str, type))): + raise TypeError( + f"dumpers should be registered on classes, got {cls} instead" + ) + + if _psycopg: + dumper = self._get_optimised(dumper) + + # Register the dumper both as its format and as auto + # so that the last dumper registered is used in auto (%s) format + if cls: + for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO): + if not self._own_dumpers[fmt]: + self._dumpers[fmt] = self._dumpers[fmt].copy() + self._own_dumpers[fmt] = True + + self._dumpers[fmt][cls] = dumper + + # Register the dumper by oid, if the oid of the dumper is fixed + if dumper.oid: + if not self._own_dumpers_by_oid[dumper.format]: + self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[ + dumper.format + ].copy() + self._own_dumpers_by_oid[dumper.format] = True + + self._dumpers_by_oid[dumper.format][dumper.oid] = dumper + + def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None: + """ + Configure the context to use `!loader` to convert data of oid `!oid`. + + :param oid: The PostgreSQL OID or type name to manage. + :param loader: The loar to register for `!oid`. + + If `oid` is specified as string, it refers to a type name, which is + looked up in the `types` registry. ` + + """ + if isinstance(oid, str): + oid = self.types[oid].oid + if not isinstance(oid, int): + raise TypeError(f"loaders should be registered on oid, got {oid} instead") + + if _psycopg: + loader = self._get_optimised(loader) + + fmt = loader.format + if not self._own_loaders[fmt]: + self._loaders[fmt] = self._loaders[fmt].copy() + self._own_loaders[fmt] = True + + self._loaders[fmt][oid] = loader + + def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]: + """ + Return the dumper class for the given type and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param cls: The class to adapt. + :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`, + use the last one of the dumpers registered on `!cls`. + """ + try: + dmap = self._dumpers[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + # Look for the right class, including looking at superclasses + for scls in cls.__mro__: + if scls in dmap: + return dmap[scls] + + # If the adapter is not found, look for its name as a string + fqn = scls.__module__ + "." + scls.__qualname__ + if fqn in dmap: + # Replace the class name with the class itself + d = dmap[scls] = dmap.pop(fqn) + return d + + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'" + f" (format: {PyFormat(format).name})" + ) + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]: + """ + Return the dumper class for the given oid and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param oid: The oid of the type to dump to. + :param format: The format to dump to. + """ + try: + dmap = self._dumpers_by_oid[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + try: + return dmap[oid] + except KeyError: + info = self.types.get(oid) + if info: + msg = ( + f"cannot find a dumper for type {info.name} (oid {oid})" + f" format {pq.Format(format).name}" + ) + else: + msg = ( + f"cannot find a dumper for unknown type with oid {oid}" + f" format {pq.Format(format).name}" + ) + raise e.ProgrammingError(msg) + + def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]: + """ + Return the loader class for the given oid and format. + + Return `!None` if not found. + + :param oid: The oid of the type to load. + :param format: The format to load from. + """ + return self._loaders[format].get(oid) + + @classmethod + def _get_optimised(self, cls: Type[RV]) -> Type[RV]: + """Return the optimised version of a Dumper or Loader class. + + Return the input class itself if there is no optimised version. + """ + try: + return self._optimised[cls] + except KeyError: + pass + + # Check if the class comes from psycopg.types and there is a class + # with the same name in psycopg_c._psycopg. + from psycopg import types + + if cls.__module__.startswith(types.__name__): + new = cast(Type[RV], getattr(_psycopg, cls.__name__, None)) + if new: + self._optimised[cls] = new + return new + + self._optimised[cls] = cls + return cls + + +# Micro-optimization: copying these objects is faster than creating new dicts +_dumpers_owned = dict.fromkeys(PyFormat, True) +_dumpers_shared = dict.fromkeys(PyFormat, False) diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py new file mode 100644 index 0000000..288ef1b --- /dev/null +++ b/psycopg/psycopg/_cmodule.py @@ -0,0 +1,24 @@ +""" +Simplify access to the _psycopg module +""" + +# Copyright (C) 2021 The Psycopg Team + +from typing import Optional + +from . import pq + +__version__: Optional[str] = None + +# Note: "c" must the first attempt so that mypy associates the variable the +# right module interface. It will not result Optional, but hey. +if pq.__impl__ == "c": + from psycopg_c import _psycopg as _psycopg + from psycopg_c import __version__ as __version__ # noqa: F401 +elif pq.__impl__ == "binary": + from psycopg_binary import _psycopg as _psycopg # type: ignore + from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401 +elif pq.__impl__ == "python": + _psycopg = None # type: ignore +else: + raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}") diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py new file mode 100644 index 0000000..9e4e735 --- /dev/null +++ b/psycopg/psycopg/_column.py @@ -0,0 +1,143 @@ +""" +The Column object in Cursor.description +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING +from operator import attrgetter + +if TYPE_CHECKING: + from .cursor import BaseCursor + + +class ColumnData(NamedTuple): + ftype: int + fmod: int + fsize: int + + +class Column(Sequence[Any]): + + __module__ = "psycopg" + + def __init__(self, cursor: "BaseCursor[Any, Any]", index: int): + res = cursor.pgresult + assert res + + fname = res.fname(index) + if fname: + self._name = fname.decode(cursor._encoding) + else: + # COPY_OUT results have columns but no name + self._name = f"column_{index + 1}" + + self._data = ColumnData( + ftype=res.ftype(index), + fmod=res.fmod(index), + fsize=res.fsize(index), + ) + self._type = cursor.adapters.types.get(self._data.ftype) + + _attrs = tuple( + attrgetter(attr) + for attr in """ + name type_code display_size internal_size precision scale null_ok + """.split() + ) + + def __repr__(self) -> str: + return ( + f"<Column {self.name!r}," + f" type: {self._type_display()} (oid: {self.type_code})>" + ) + + def __len__(self) -> int: + return 7 + + def _type_display(self) -> str: + parts = [] + parts.append(self._type.name if self._type else str(self.type_code)) + + mod1 = self.precision + if mod1 is None: + mod1 = self.display_size + if mod1: + parts.append(f"({mod1}") + if self.scale: + parts.append(f", {self.scale}") + parts.append(")") + + if self._type and self.type_code == self._type.array_oid: + parts.append("[]") + + return "".join(parts) + + def __getitem__(self, index: Any) -> Any: + if isinstance(index, slice): + return tuple(getter(self) for getter in self._attrs[index]) + else: + return self._attrs[index](self) + + @property + def name(self) -> str: + """The name of the column.""" + return self._name + + @property + def type_code(self) -> int: + """The numeric OID of the column.""" + return self._data.ftype + + @property + def display_size(self) -> Optional[int]: + """The field size, for :sql:`varchar(n)`, None otherwise.""" + if not self._type: + return None + + if self._type.name in ("varchar", "char"): + fmod = self._data.fmod + if fmod >= 0: + return fmod - 4 + + return None + + @property + def internal_size(self) -> Optional[int]: + """The internal field size for fixed-size types, None otherwise.""" + fsize = self._data.fsize + return fsize if fsize >= 0 else None + + @property + def precision(self) -> Optional[int]: + """The number of digits for fixed precision types.""" + if not self._type: + return None + + dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval") + if self._type.name == "numeric": + fmod = self._data.fmod + if fmod >= 0: + return fmod >> 16 + + elif self._type.name in dttypes: + fmod = self._data.fmod + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def scale(self) -> Optional[int]: + """The number of digits after the decimal point if available.""" + if self._type and self._type.name == "numeric": + fmod = self._data.fmod - 4 + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def null_ok(self) -> Optional[bool]: + """Always `!None`""" + return None diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py new file mode 100644 index 0000000..7dbae79 --- /dev/null +++ b/psycopg/psycopg/_compat.py @@ -0,0 +1,72 @@ +""" +compatibility functions for different Python versions +""" + +# Copyright (C) 2021 The Psycopg Team + +import sys +import asyncio +from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar + +# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it. +# For this raisin it must be imported directly from typing_extension where used. +# See https://github.com/microsoft/pyright/issues/4197 +from typing_extensions import TypeAlias + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +T = TypeVar("T") +FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] + +if sys.version_info >= (3, 8): + create_task = asyncio.create_task + from math import prod + +else: + + def create_task( + coro: FutureT[T], name: Optional[str] = None + ) -> "asyncio.Future[T]": + return asyncio.create_task(coro) + + from functools import reduce + + def prod(seq: Sequence[int]) -> int: + return reduce(int.__mul__, seq, 1) + + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo + from functools import cache + from collections import Counter, deque as Deque +else: + from typing import Counter, Deque + from functools import lru_cache + from backports.zoneinfo import ZoneInfo + + cache = lru_cache(maxsize=None) + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +if sys.version_info >= (3, 11): + from typing import LiteralString +else: + from typing_extensions import LiteralString + +__all__ = [ + "Counter", + "Deque", + "LiteralString", + "Protocol", + "TypeGuard", + "ZoneInfo", + "cache", + "create_task", + "prod", +] diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py new file mode 100644 index 0000000..1e146ba --- /dev/null +++ b/psycopg/psycopg/_dns.py @@ -0,0 +1,223 @@ +# type: ignore # dnspython is currently optional and mypy fails if missing +""" +DNS query support +""" + +# Copyright (C) 2021 The Psycopg Team + +import os +import re +import warnings +from random import randint +from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence +from typing import TYPE_CHECKING +from collections import defaultdict + +try: + from dns.resolver import Resolver, Cache + from dns.asyncresolver import Resolver as AsyncResolver + from dns.exception import DNSException +except ImportError: + raise ImportError( + "the module psycopg._dns requires the package 'dnspython' installed" + ) + +from . import errors as e +from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_ + +if TYPE_CHECKING: + from dns.rdtypes.IN.SRV import SRV + +resolver = Resolver() +resolver.cache = Cache() + +async_resolver = AsyncResolver() +async_resolver.cache = Cache() + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + .. deprecated:: 3.1 + The use of this function is not necessary anymore, because + `psycopg.AsyncConnection.connect()` performs non-blocking name + resolution automatically. + """ + warnings.warn( + "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore", + DeprecationWarning, + ) + return await resolve_hostaddr_async_(params) + + +def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]: + """Apply SRV DNS lookup as defined in :RFC:`2782`.""" + return Rfc2782Resolver().resolve(params) + + +async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]: + """Async equivalent of `resolve_srv()`.""" + return await Rfc2782Resolver().resolve_async(params) + + +class HostPort(NamedTuple): + host: str + port: str + totry: bool = False + target: Optional[str] = None + + +class Rfc2782Resolver: + """Implement SRV RR Resolution as per RFC 2782 + + The class is organised to minimise code duplication between the sync and + the async paths. + """ + + re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)") + + def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(self._resolve_srv(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(await self._resolve_srv_async(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]: + """ + Return the list of host, and for each host if SRV lookup must be tried. + + Return an empty list if no lookup is requested. + """ + # If hostaddr is defined don't do any resolution. + if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")): + return [] + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") + + if len(ports_in) == 1: + # If only one port is specified, it applies to all the hosts. + ports_in *= len(hosts_in) + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. + raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + + out = [] + srv_found = False + for host, port in zip(hosts_in, ports_in): + m = self.re_srv_rr.match(host) + if m or port.lower() == "srv": + srv_found = True + target = m.group("target") if m else None + hp = HostPort(host=host, port=port, totry=True, target=target) + else: + hp = HostPort(host=host, port=port) + out.append(hp) + + return out if srv_found else [] + + def _resolve_srv(self, hp: HostPort) -> List[HostPort]: + try: + ans = resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]: + try: + ans = await async_resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + def _get_solved_entries( + self, hp: HostPort, entries: "Sequence[SRV]" + ) -> List[HostPort]: + if not entries: + # No SRV entry found. Delegate the libpq a QNAME=target lookup + if hp.target and hp.port.lower() != "srv": + return [HostPort(host=hp.target, port=hp.port)] + else: + return [] + + # If there is precisely one SRV RR, and its Target is "." (the root + # domain), abort. + if len(entries) == 1 and str(entries[0].target) == ".": + return [] + + return [ + HostPort(host=str(entry.target).rstrip("."), port=str(entry.port)) + for entry in self.sort_rfc2782(entries) + ] + + def _return_params( + self, params: Dict[str, Any], hps: List[HostPort] + ) -> Dict[str, Any]: + if not hps: + # Nothing found, we ended up with an empty list + raise e.OperationalError("no host found after SRV RR lookup") + + out = params.copy() + out["host"] = ",".join(hp.host for hp in hps) + out["port"] = ",".join(str(hp.port) for hp in hps) + return out + + def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]": + """ + Implement the priority/weight ordering defined in RFC 2782. + """ + # Divide the entries by priority: + priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list) + out: "List[SRV]" = [] + for entry in ans: + priorities[entry.priority].append(entry) + + for pri, entries in sorted(priorities.items()): + if len(entries) == 1: + out.append(entries[0]) + continue + + entries.sort(key=lambda ent: ent.weight) + total_weight = sum(ent.weight for ent in entries) + while entries: + r = randint(0, total_weight) + csum = 0 + for i, ent in enumerate(entries): + csum += ent.weight + if csum >= r: + break + out.append(ent) + total_weight -= ent.weight + del entries[i] + + return out diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py new file mode 100644 index 0000000..c584b26 --- /dev/null +++ b/psycopg/psycopg/_encodings.py @@ -0,0 +1,170 @@ +""" +Mappings between PostgreSQL and Python encodings. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import string +import codecs +from typing import Any, Dict, Optional, TYPE_CHECKING + +from .pq._enums import ConnStatus +from .errors import NotSupportedError +from ._compat import cache + +if TYPE_CHECKING: + from .pq.abc import PGconn + from .connection import BaseConnection + +OK = ConnStatus.OK + + +_py_codecs = { + "BIG5": "big5", + "EUC_CN": "gb2312", + "EUC_JIS_2004": "euc_jis_2004", + "EUC_JP": "euc_jp", + "EUC_KR": "euc_kr", + # "EUC_TW": not available in Python + "GB18030": "gb18030", + "GBK": "gbk", + "ISO_8859_5": "iso8859-5", + "ISO_8859_6": "iso8859-6", + "ISO_8859_7": "iso8859-7", + "ISO_8859_8": "iso8859-8", + "JOHAB": "johab", + "KOI8R": "koi8-r", + "KOI8U": "koi8-u", + "LATIN1": "iso8859-1", + "LATIN10": "iso8859-16", + "LATIN2": "iso8859-2", + "LATIN3": "iso8859-3", + "LATIN4": "iso8859-4", + "LATIN5": "iso8859-9", + "LATIN6": "iso8859-10", + "LATIN7": "iso8859-13", + "LATIN8": "iso8859-14", + "LATIN9": "iso8859-15", + # "MULE_INTERNAL": not available in Python + "SHIFT_JIS_2004": "shift_jis_2004", + "SJIS": "shift_jis", + # this actually means no encoding, see PostgreSQL docs + # it is special-cased by the text loader. + "SQL_ASCII": "ascii", + "UHC": "cp949", + "UTF8": "utf-8", + "WIN1250": "cp1250", + "WIN1251": "cp1251", + "WIN1252": "cp1252", + "WIN1253": "cp1253", + "WIN1254": "cp1254", + "WIN1255": "cp1255", + "WIN1256": "cp1256", + "WIN1257": "cp1257", + "WIN1258": "cp1258", + "WIN866": "cp866", + "WIN874": "cp874", +} + +py_codecs: Dict[bytes, str] = {} +py_codecs.update((k.encode(), v) for k, v in _py_codecs.items()) + +# Add an alias without underscore, for lenient lookups +py_codecs.update( + (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k +) + +pg_codecs = {v: k.encode() for k, v in _py_codecs.items()} + + +def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str: + """ + Return the Python encoding name of a psycopg connection. + + Default to utf8 if the connection has no encoding info. + """ + if not conn or conn.closed: + return "utf-8" + + pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def pgconn_encoding(pgconn: "PGconn") -> str: + """ + Return the Python encoding name of a libpq connection. + + Default to utf8 if the connection has no encoding info. + """ + if pgconn.status != OK: + return "utf-8" + + pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def conninfo_encoding(conninfo: str) -> str: + """ + Return the Python encoding name passed in a conninfo string. Default to utf8. + + Because the input is likely to come from the user and not normalised by the + server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars). + """ + from .conninfo import conninfo_to_dict + + params = conninfo_to_dict(conninfo) + pgenc = params.get("client_encoding") + if pgenc: + try: + return pg2pyenc(pgenc.encode()) + except NotSupportedError: + pass + + return "utf-8" + + +@cache +def py2pgenc(name: str) -> bytes: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise LookupError if the Python encoding is unknown. + """ + return pg_codecs[codecs.lookup(name).name] + + +@cache +def pg2pyenc(name: bytes) -> str: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise NotSupportedError if the PostgreSQL encoding is not supported by + Python. + """ + try: + return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()] + except KeyError: + sname = name.decode("utf8", "replace") + raise NotSupportedError(f"codec not available in Python: {sname!r}") + + +def _as_python_identifier(s: str, prefix: str = "f") -> str: + """ + Reduce a string to a valid Python identifier. + + Replace all non-valid chars with '_' and prefix the value with `!prefix` if + the first letter is an '_'. + """ + if not s.isidentifier(): + if s[0] in "1234567890": + s = prefix + s + if not s.isidentifier(): + s = _re_clean.sub("_", s) + # namedtuple fields cannot start with underscore. So... + if s[0] == "_": + s = prefix + s + return s + + +_re_clean = re.compile( + f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]" +) diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py new file mode 100644 index 0000000..a7cb78d --- /dev/null +++ b/psycopg/psycopg/_enums.py @@ -0,0 +1,79 @@ +""" +Enum values for psycopg + +These values are defined by us and are not necessarily dependent on +libpq-defined enums. +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import Enum, IntEnum +from selectors import EVENT_READ, EVENT_WRITE + +from . import pq + + +class Wait(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class Ready(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class PyFormat(str, Enum): + """ + Enum representing the format wanted for a query argument. + + The value `AUTO` allows psycopg to choose the best format for a certain + parameter. + """ + + __module__ = "psycopg.adapt" + + AUTO = "s" + """Automatically chosen (``%s`` placeholder).""" + TEXT = "t" + """Text parameter (``%t`` placeholder).""" + BINARY = "b" + """Binary parameter (``%b`` placeholder).""" + + @classmethod + def from_pq(cls, fmt: pq.Format) -> "PyFormat": + return _pg2py[fmt] + + @classmethod + def as_pq(cls, fmt: "PyFormat") -> pq.Format: + return _py2pg[fmt] + + +class IsolationLevel(IntEnum): + """ + Enum representing the isolation level for a transaction. + """ + + __module__ = "psycopg" + + READ_UNCOMMITTED = 1 + """:sql:`READ UNCOMMITTED` isolation level.""" + READ_COMMITTED = 2 + """:sql:`READ COMMITTED` isolation level.""" + REPEATABLE_READ = 3 + """:sql:`REPEATABLE READ` isolation level.""" + SERIALIZABLE = 4 + """:sql:`SERIALIZABLE` isolation level.""" + + +_py2pg = { + PyFormat.TEXT: pq.Format.TEXT, + PyFormat.BINARY: pq.Format.BINARY, +} + +_pg2py = { + pq.Format.TEXT: PyFormat.TEXT, + pq.Format.BINARY: PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py new file mode 100644 index 0000000..c818d86 --- /dev/null +++ b/psycopg/psycopg/_pipeline.py @@ -0,0 +1,288 @@ +""" +commands pipeline management +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +from types import TracebackType +from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from .abc import PipelineCommand, PQGen +from ._compat import Deque +from ._encodings import pgconn_encoding +from ._preparing import Key, Prepare +from .generators import pipeline_communicate, fetch_many, send + +if TYPE_CHECKING: + from .pq.abc import PGresult + from .cursor import BaseCursor + from .connection import BaseConnection, Connection + from .connection_async import AsyncConnection + + +PendingResult: TypeAlias = Union[ + None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]] +] + +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED +BAD = pq.ConnStatus.BAD + +ACTIVE = pq.TransactionStatus.ACTIVE + +logger = logging.getLogger("psycopg") + + +class BasePipeline: + + command_queue: Deque[PipelineCommand] + result_queue: Deque[PendingResult] + _is_supported: Optional[bool] = None + + def __init__(self, conn: "BaseConnection[Any]") -> None: + self._conn = conn + self.pgconn = conn.pgconn + self.command_queue = Deque[PipelineCommand]() + self.result_queue = Deque[PendingResult]() + self.level = 0 + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._conn.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def status(self) -> pq.PipelineStatus: + return pq.PipelineStatus(self.pgconn.pipeline_status) + + @classmethod + def is_supported(cls) -> bool: + """Return `!True` if the psycopg libpq wrapper supports pipeline mode.""" + if BasePipeline._is_supported is None: + BasePipeline._is_supported = not cls._not_supported_reason() + return BasePipeline._is_supported + + @classmethod + def _not_supported_reason(cls) -> str: + """Return the reason why the pipeline mode is not supported. + + Return an empty string if pipeline mode is supported. + """ + # Support only depends on the libpq functions available in the pq + # wrapper, not on the database version. + if pq.version() < 140000: + return ( + f"libpq too old {pq.version()};" + " v14 or greater required for pipeline mode" + ) + + if pq.__build_version__ < 140000: + return ( + f"libpq too old: module built for {pq.__build_version__};" + " v14 or greater required for pipeline mode" + ) + + return "" + + def _enter_gen(self) -> PQGen[None]: + if not self.is_supported(): + raise e.NotSupportedError( + f"pipeline mode not supported: {self._not_supported_reason()}" + ) + if self.level == 0: + self.pgconn.enter_pipeline_mode() + elif self.command_queue or self.pgconn.transaction_status == ACTIVE: + # Nested pipeline case. + # Transaction might be ACTIVE when the pipeline uses an "implicit + # transaction", typically in autocommit mode. But when entering a + # Psycopg transaction(), we expect the IDLE state. By sync()-ing, + # we make sure all previous commands are completed and the + # transaction gets back to IDLE. + yield from self._sync_gen() + self.level += 1 + + def _exit(self, exc: Optional[BaseException]) -> None: + self.level -= 1 + if self.level == 0 and self.pgconn.status != BAD: + try: + self.pgconn.exit_pipeline_mode() + except e.OperationalError as exc2: + # Notice that this error might be pretty irrecoverable. It + # happens on COPY, for instance: even if sync succeeds, exiting + # fails with "cannot exit pipeline mode with uncollected results" + if exc: + logger.warning("error ignored exiting %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + + def _sync_gen(self) -> PQGen[None]: + self._enqueue_sync() + yield from self._communicate_gen() + yield from self._fetch_gen(flush=False) + + def _exit_gen(self) -> PQGen[None]: + """ + Exit current pipeline by sending a Sync and fetch back all remaining results. + """ + try: + self._enqueue_sync() + yield from self._communicate_gen() + finally: + # No need to force flush since we emitted a sync just before. + yield from self._fetch_gen(flush=False) + + def _communicate_gen(self) -> PQGen[None]: + """Communicate with pipeline to send commands and possibly fetch + results, which are then processed. + """ + fetched = yield from pipeline_communicate(self.pgconn, self.command_queue) + to_process = [(self.result_queue.popleft(), results) for results in fetched] + for queued, results in to_process: + self._process_results(queued, results) + + def _fetch_gen(self, *, flush: bool) -> PQGen[None]: + """Fetch available results from the connection and process them with + pipeline queued items. + + If 'flush' is True, a PQsendFlushRequest() is issued in order to make + sure results can be fetched. Otherwise, the caller may emit a + PQpipelineSync() call to ensure the output buffer gets flushed before + fetching. + """ + if not self.result_queue: + return + + if flush: + self.pgconn.send_flush_request() + yield from send(self.pgconn) + + to_process = [] + while self.result_queue: + results = yield from fetch_many(self.pgconn) + if not results: + # No more results to fetch, but there may still be pending + # commands. + break + queued = self.result_queue.popleft() + to_process.append((queued, results)) + + for queued, results in to_process: + self._process_results(queued, results) + + def _process_results( + self, queued: PendingResult, results: List["PGresult"] + ) -> None: + """Process a results set fetched from the current pipeline. + + This matches 'results' with its respective element in the pipeline + queue. For commands (None value in the pipeline queue), results are + checked directly. For prepare statement creation requests, update the + cache. Otherwise, results are attached to their respective cursor. + """ + if queued is None: + (result,) = results + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + elif result.status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + cursor, prepinfo = queued + cursor._set_results_from_pipeline(results) + if prepinfo: + key, prep, name = prepinfo + # Update the prepare state of the query. + cursor._conn._prepared.validate(key, prep, name, results) + + def _enqueue_sync(self) -> None: + """Enqueue a PQpipelineSync() command.""" + self.command_queue.append(self.pgconn.pipeline_sync) + self.result_queue.append(None) + + +class Pipeline(BasePipeline): + """Handler for connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "Connection[Any]" + _Self = TypeVar("_Self", bound="Pipeline") + + def __init__(self, conn: "Connection[Any]") -> None: + super().__init__(conn) + + def sync(self) -> None: + """Sync the pipeline, send any pending command and receive and process + all available results. + """ + try: + with self._conn.lock: + self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + with self._conn.lock: + self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) + + +class AsyncPipeline(BasePipeline): + """Handler for async connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "AsyncConnection[Any]" + _Self = TypeVar("_Self", bound="AsyncPipeline") + + def __init__(self, conn: "AsyncConnection[Any]") -> None: + super().__init__(conn) + + async def sync(self) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py new file mode 100644 index 0000000..f60c0cb --- /dev/null +++ b/psycopg/psycopg/_preparing.py @@ -0,0 +1,198 @@ +""" +Support for prepared statements +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, auto +from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING +from collections import OrderedDict +from typing_extensions import TypeAlias + +from . import pq +from ._compat import Deque +from ._queries import PostgresQuery + +if TYPE_CHECKING: + from .pq.abc import PGresult + +Key: TypeAlias = Tuple[bytes, Tuple[int, ...]] + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + + +class Prepare(IntEnum): + NO = auto() + YES = auto() + SHOULD = auto() + + +class PrepareManager: + # Number of times a query is executed before it is prepared. + prepare_threshold: Optional[int] = 5 + + # Maximum number of prepared statements on the connection. + prepared_max: int = 100 + + def __init__(self) -> None: + # Map (query, types) to the number of times the query was seen. + self._counts: OrderedDict[Key, int] = OrderedDict() + + # Map (query, types) to the name of the statement if prepared. + self._names: OrderedDict[Key, bytes] = OrderedDict() + + # Counter to generate prepared statements names + self._prepared_idx = 0 + + self._maint_commands = Deque[bytes]() + + @staticmethod + def key(query: PostgresQuery) -> Key: + return (query.query, query.types) + + def get( + self, query: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + """ + Check if a query is prepared, tell back whether to prepare it. + """ + if prepare is False or self.prepare_threshold is None: + # The user doesn't want this query to be prepared + return Prepare.NO, b"" + + key = self.key(query) + name = self._names.get(key) + if name: + # The query was already prepared in this session + return Prepare.YES, name + + count = self._counts.get(key, 0) + if count >= self.prepare_threshold or prepare: + # The query has been executed enough times and needs to be prepared + name = f"_pg3_{self._prepared_idx}".encode() + self._prepared_idx += 1 + return Prepare.SHOULD, name + else: + # The query is not to be prepared yet + return Prepare.NO, b"" + + def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool: + """Check if we need to discard our entire state: it should happen on + rollback or on dropping objects, because the same object may get + recreated and postgres would fail internal lookups. + """ + if self._names or prep == Prepare.SHOULD: + for result in results: + if result.status != COMMAND_OK: + continue + cmdstat = result.command_status + if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"): + return self.clear() + return False + + @staticmethod + def _check_results(results: Sequence["PGresult"]) -> bool: + """Return False if 'results' are invalid for prepared statement cache.""" + if len(results) != 1: + # We cannot prepare a multiple statement + return False + + status = results[0].status + if COMMAND_OK != status != TUPLES_OK: + # We don't prepare failed queries or other weird results + return False + + return True + + def _rotate(self) -> None: + """Evict an old value from the cache. + + If it was prepared, deallocate it. Do it only once: if the cache was + resized, deallocate gradually. + """ + if len(self._counts) > self.prepared_max: + self._counts.popitem(last=False) + + if len(self._names) > self.prepared_max: + name = self._names.popitem(last=False)[1] + self._maint_commands.append(b"DEALLOCATE " + name) + + def maybe_add_to_cache( + self, query: PostgresQuery, prep: Prepare, name: bytes + ) -> Optional[Key]: + """Handle 'query' for possible addition to the cache. + + If a new entry has been added, return its key. Return None otherwise + (meaning the query is already in cache or cache is not enabled). + + Note: This method is only called in pipeline mode. + """ + # don't do anything if prepared statements are disabled + if self.prepare_threshold is None: + return None + + key = self.key(query) + if key in self._counts: + if prep is Prepare.SHOULD: + del self._counts[key] + self._names[key] = name + else: + self._counts[key] += 1 + self._counts.move_to_end(key) + return None + + elif key in self._names: + self._names.move_to_end(key) + return None + + else: + if prep is Prepare.SHOULD: + self._names[key] = name + else: + self._counts[key] = 1 + return key + + def validate( + self, + key: Key, + prep: Prepare, + name: bytes, + results: Sequence["PGresult"], + ) -> None: + """Validate cached entry with 'key' by checking query 'results'. + + Possibly return a command to perform maintenance on database side. + + Note: this method is only called in pipeline mode. + """ + if self._should_discard(prep, results): + return + + if not self._check_results(results): + self._names.pop(key, None) + self._counts.pop(key, None) + else: + self._rotate() + + def clear(self) -> bool: + """Clear the cache of the maintenance commands. + + Clear the internal state and prepare a command to clear the state of + the server. + """ + self._counts.clear() + if self._names: + self._names.clear() + self._maint_commands.clear() + self._maint_commands.append(b"DEALLOCATE ALL") + return True + else: + return False + + def get_maintenance_commands(self) -> Iterator[bytes]: + """ + Iterate over the commands needed to align the server state to our state + """ + while self._maint_commands: + yield self._maint_commands.popleft() diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py new file mode 100644 index 0000000..2a7554c --- /dev/null +++ b/psycopg/psycopg/_queries.py @@ -0,0 +1,375 @@ +""" +Utility module to manipulate queries +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional +from typing import Sequence, Tuple, Union, TYPE_CHECKING +from functools import lru_cache + +from . import pq +from . import errors as e +from .sql import Composable +from .abc import Buffer, Query, Params +from ._enums import PyFormat +from ._encodings import conn_encoding + +if TYPE_CHECKING: + from .abc import Transformer + + +class QueryPart(NamedTuple): + pre: bytes + item: Union[int, str] + format: PyFormat + + +class PostgresQuery: + """ + Helper to convert a Python query and parameters into Postgres format. + """ + + __slots__ = """ + query params types formats + _tx _want_formats _parts _encoding _order + """.split() + + def __init__(self, transformer: "Transformer"): + self._tx = transformer + + self.params: Optional[Sequence[Optional[Buffer]]] = None + # these are tuples so they can be used as keys e.g. in prepared stmts + self.types: Tuple[int, ...] = () + + # The format requested by the user and the ones to really pass Postgres + self._want_formats: Optional[List[PyFormat]] = None + self.formats: Optional[Sequence[pq.Format]] = None + + self._encoding = conn_encoding(transformer.connection) + self._parts: List[QueryPart] + self.query = b"" + self._order: Optional[List[str]] = None + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). + """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + ( + self.query, + self._want_formats, + self._order, + self._parts, + ) = _query2pg(bquery, self._encoding) + else: + self.query = bquery + self._want_formats = self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + assert self._want_formats is not None + self.params = self._tx.dump_sequence(params, self._want_formats) + self.types = self._tx.types or () + self.formats = self._tx.formats + else: + self.params = None + self.types = () + self.formats = None + + +class PostgresClientQuery(PostgresQuery): + """ + PostgresQuery subclass merging query and arguments client-side. + """ + + __slots__ = ("template",) + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). + """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + (self.template, self._order, self._parts) = _query2pg_client( + bquery, self._encoding + ) + else: + self.query = bquery + self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + self.params = tuple( + self._tx.as_literal(p) if p is not None else b"NULL" for p in params + ) + self.query = self.template % self.params + else: + self.params = None + + +@lru_cache() +def _query2pg( + query: bytes, encoding: str +) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into something Postgres understands. + + - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres + format (``$1``, ``$2``) + - placeholders can be %s, %t, or %b (auto, text or binary) + - return ``query`` (bytes), ``formats`` (list of formats) ``order`` + (sequence of names used in the query, in the position they appear) + ``parts`` (splits of queries and placeholders). + """ + parts = _split_query(query, encoding) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + formats = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"$%d" % (part.item + 1)) + formats.append(part.format) + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"$%d" % (len(seen) + 1) + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + formats.append(part.format) + else: + if seen[part.item][1] != part.format: + raise e.ProgrammingError( + f"placeholder '{part.item}' cannot have different formats" + ) + chunks.append(seen[part.item][0]) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), formats, order, parts + + +@lru_cache() +def _query2pg_client( + query: bytes, encoding: str +) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into a template to perform client-side binding + """ + parts = _split_query(query, encoding, collapse_double_percent=False) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"%s") + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"%s" + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + else: + chunks.append(seen[part.item][0]) + order.append(part.item) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), order, parts + + +def _validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] +) -> Sequence[Any]: + """ + Verify the compatibility between a query and a set of params. + """ + # Try concrete types, then abstract types + t = type(vars) + if t is list or t is tuple: + sequence = True + elif t is dict: + sequence = False + elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + sequence = True + elif isinstance(vars, Mapping): + sequence = False + else: + raise TypeError( + "query parameters should be a sequence or a mapping," + f" got {type(vars).__name__}" + ) + + if sequence: + if len(vars) != len(parts) - 1: + raise e.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0].item, int): + raise TypeError("named placeholders require a mapping of parameters") + return vars # type: ignore[return-value] + + else: + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + try: + return [vars[item] for item in order or ()] # type: ignore[call-overload] + except KeyError: + raise e.ProgrammingError( + "query parameter missing:" + f" {', '.join(sorted(i for i in order or () if i not in vars))}" + ) + + +_re_placeholder = re.compile( + rb"""(?x) + % # a literal % + (?: + (?: + \( ([^)]+) \) # or a name in (braces) + . # followed by a format + ) + | + (?:.) # or any char, really + ) + """ +) + + +def _split_query( + query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True +) -> List[QueryPart]: + parts: List[Tuple[bytes, Optional[Match[bytes]]]] = [] + cur = 0 + + # pairs [(fragment, match], with the last match None + m = None + for m in _re_placeholder.finditer(query): + pre = query[cur : m.span(0)[0]] + parts.append((pre, m)) + cur = m.span(0)[1] + if m: + parts.append((query[cur:], None)) + else: + parts.append((query, None)) + + rv = [] + + # drop the "%%", validate + i = 0 + phtype = None + while i < len(parts): + pre, m = parts[i] + if m is None: + # last part + rv.append(QueryPart(pre, 0, PyFormat.AUTO)) + break + + ph = m.group(0) + if ph == b"%%": + # unescape '%%' to '%' if necessary, then merge the parts + if collapse_double_percent: + ph = b"%" + pre1, m1 = parts[i + 1] + parts[i + 1] = (pre + ph + pre1, m1) + del parts[i] + continue + + if ph == b"%(": + raise e.ProgrammingError( + "incomplete placeholder:" + f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'" + ) + elif ph == b"% ": + # explicit messasge for a typical error + raise e.ProgrammingError( + "incomplete placeholder: '%'; if you want to use '%' as an" + " operator you can double it up, i.e. use '%%'" + ) + elif ph[-1:] not in b"sbt": + raise e.ProgrammingError( + "only '%s', '%b', '%t' are allowed as placeholders, got" + f" '{m.group(0).decode(encoding)}'" + ) + + # Index or name + item: Union[int, str] + item = m.group(1).decode(encoding) if m.group(1) else i + + if not phtype: + phtype = type(item) + elif phtype is not type(item): + raise e.ProgrammingError( + "positional and named placeholders cannot be mixed" + ) + + format = _ph_to_fmt[ph[-1:]] + rv.append(QueryPart(pre, item, format)) + i += 1 + + return rv + + +_ph_to_fmt = { + b"s": PyFormat.AUTO, + b"t": PyFormat.TEXT, + b"b": PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py new file mode 100644 index 0000000..28a6084 --- /dev/null +++ b/psycopg/psycopg/_struct.py @@ -0,0 +1,57 @@ +""" +Utility functions to deal with binary structs. +""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from typing import Callable, cast, Optional, Tuple +from typing_extensions import TypeAlias + +from .abc import Buffer +from . import errors as e +from ._compat import Protocol + +PackInt: TypeAlias = Callable[[int], bytes] +UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]] +PackFloat: TypeAlias = Callable[[float], bytes] +UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]] + + +class UnpackLen(Protocol): + def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: + ... + + +pack_int2 = cast(PackInt, struct.Struct("!h").pack) +pack_uint2 = cast(PackInt, struct.Struct("!H").pack) +pack_int4 = cast(PackInt, struct.Struct("!i").pack) +pack_uint4 = cast(PackInt, struct.Struct("!I").pack) +pack_int8 = cast(PackInt, struct.Struct("!q").pack) +pack_float4 = cast(PackFloat, struct.Struct("!f").pack) +pack_float8 = cast(PackFloat, struct.Struct("!d").pack) + +unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack) +unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack) +unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack) +unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack) +unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack) +unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack) +unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack) + +_struct_len = struct.Struct("!i") +pack_len = cast(Callable[[int], bytes], _struct_len.pack) +unpack_len = cast(UnpackLen, _struct_len.unpack_from) + + +def pack_float4_bug_304(x: float) -> bytes: + raise e.InterfaceError( + "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c" + " and psycopg-binary packages are not affected by this issue." + " See https://github.com/psycopg/psycopg/issues/304" + ) + + +# If issue #304 is detected, raise an error instead of dumping wrong data. +if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"): + pack_float4 = pack_float4_bug_304 diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py new file mode 100644 index 0000000..3528188 --- /dev/null +++ b/psycopg/psycopg/_tpc.py @@ -0,0 +1,116 @@ +""" +psycopg two-phase commit support +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +import datetime as dt +from base64 import b64encode, b64decode +from typing import Optional, Union +from dataclasses import dataclass, replace + +_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$") + + +@dataclass(frozen=True) +class Xid: + """A two-phase commit transaction identifier. + + The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`, + `bqual`). + + """ + + format_id: Optional[int] + gtrid: str + bqual: Optional[str] + prepared: Optional[dt.datetime] = None + owner: Optional[str] = None + database: Optional[str] = None + + @classmethod + def from_string(cls, s: str) -> "Xid": + """Try to parse an XA triple from the string. + + This may fail for several reasons. In such case return an unparsed Xid. + """ + try: + return cls._parse_string(s) + except Exception: + return Xid(None, s, None) + + def __str__(self) -> str: + return self._as_tid() + + def __len__(self) -> int: + return 3 + + def __getitem__(self, index: int) -> Union[int, str, None]: + return (self.format_id, self.gtrid, self.bqual)[index] + + @classmethod + def _parse_string(cls, s: str) -> "Xid": + m = _re_xid.match(s) + if not m: + raise ValueError("bad Xid format") + + format_id = int(m.group(1)) + gtrid = b64decode(m.group(2)).decode() + bqual = b64decode(m.group(3)).decode() + return cls.from_parts(format_id, gtrid, bqual) + + @classmethod + def from_parts( + cls, format_id: Optional[int], gtrid: str, bqual: Optional[str] + ) -> "Xid": + if format_id is not None: + if bqual is None: + raise TypeError("if format_id is specified, bqual must be too") + if not 0 <= format_id < 0x80000000: + raise ValueError("format_id must be a non-negative 32-bit integer") + if len(bqual) > 64: + raise ValueError("bqual must be not longer than 64 chars") + if len(gtrid) > 64: + raise ValueError("gtrid must be not longer than 64 chars") + + elif bqual is None: + raise TypeError("if format_id is None, bqual must be None too") + + return Xid(format_id, gtrid, bqual) + + def _as_tid(self) -> str: + """ + Return the PostgreSQL transaction_id for this XA xid. + + PostgreSQL wants just a string, while the DBAPI supports the XA + standard and thus a triple. We use the same conversion algorithm + implemented by JDBC in order to allow some form of interoperation. + + see also: the pgjdbc implementation + http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/ + postgresql/xa/RecoveredXid.java?rev=1.2 + """ + if self.format_id is None or self.bqual is None: + # Unparsed xid: return the gtrid. + return self.gtrid + + # XA xid: mash together the components. + egtrid = b64encode(self.gtrid.encode()).decode() + ebqual = b64encode(self.bqual.encode()).decode() + + return f"{self.format_id}_{egtrid}_{ebqual}" + + @classmethod + def _get_recover_query(cls) -> str: + return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts" + + @classmethod + def _from_record( + cls, gid: str, prepared: dt.datetime, owner: str, database: str + ) -> "Xid": + xid = Xid.from_string(gid) + return replace(xid, prepared=prepared, owner=owner, database=database) + + +Xid.__module__ = "psycopg" diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py new file mode 100644 index 0000000..19bd6ae --- /dev/null +++ b/psycopg/psycopg/_transform.py @@ -0,0 +1,350 @@ +""" +Helper object to transform values between Python and PostgreSQL +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import DefaultDict, TYPE_CHECKING +from collections import defaultdict +from typing_extensions import TypeAlias + +from . import pq +from . import postgres +from . import errors as e +from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType +from .rows import Row, RowMaker +from .postgres import INVALID_OID, TEXT_OID +from ._encodings import pgconn_encoding + +if TYPE_CHECKING: + from .abc import Dumper, Loader + from .adapt import AdaptersMap + from .pq.abc import PGresult + from .connection import BaseConnection + +DumperCache: TypeAlias = Dict[DumperKey, "Dumper"] +OidDumperCache: TypeAlias = Dict[int, "Dumper"] +LoaderCache: TypeAlias = Dict[int, "Loader"] + +TEXT = pq.Format.TEXT +PY_TEXT = PyFormat.TEXT + + +class Transformer(AdaptContext): + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that attributes + such as the server version or the connection encoding will not change. The + object have its state so adapting several values of the same type can be + optimised. + + """ + + __module__ = "psycopg.adapt" + + __slots__ = """ + types formats + _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid + _oid_dumpers _oid_types _row_dumpers _row_loaders + """.split() + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + _adapters: "AdaptersMap" + _pgresult: Optional["PGresult"] + _none_oid: int + + def __init__(self, context: Optional[AdaptContext] = None): + self._pgresult = self.types = self.formats = None + + # WARNING: don't store context, or you'll create a loop with the Cursor + if context: + self._adapters = context.adapters + self._conn = context.connection + else: + self._adapters = postgres.adapters + self._conn = None + + # mapping fmt, class -> Dumper instance + self._dumpers: DefaultDict[PyFormat, DumperCache] + self._dumpers = defaultdict(dict) + + # mapping fmt, oid -> Dumper instance + # Not often used, so create it only if needed. + self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]] + self._oid_dumpers = None + + # mapping fmt, oid -> Loader instance + self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {}) + + self._row_dumpers: Optional[List["Dumper"]] = None + + # sequence of load functions from value to python + # the length of the result columns + self._row_loaders: List[LoadFunc] = [] + + # mapping oid -> type sql representation + self._oid_types: Dict[int, bytes] = {} + + self._encoding = "" + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + """ + Return a Transformer from an AdaptContext. + + If the context is a Transformer instance, just return it. + """ + if isinstance(context, Transformer): + return context + else: + return cls(context) + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return self._conn + + @property + def encoding(self) -> str: + if not self._encoding: + conn = self.connection + self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" + return self._encoding + + @property + def adapters(self) -> "AdaptersMap": + return self._adapters + + @property + def pgresult(self) -> Optional["PGresult"]: + return self._pgresult + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None, + ) -> None: + self._pgresult = result + + if not result: + self._nfields = self._ntuples = 0 + if set_loaders: + self._row_loaders = [] + return + + self._ntuples = result.ntuples + nf = self._nfields = result.nfields + + if not set_loaders: + return + + if not nf: + self._row_loaders = [] + return + + fmt: pq.Format + fmt = result.fformat(0) if format is None else format # type: ignore + self._row_loaders = [ + self.get_loader(result.ftype(i), fmt).load for i in range(nf) + ] + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types] + self.types = tuple(types) + self.formats = [format] * len(types) + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_loaders = [self.get_loader(oid, format).load for oid in types] + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + nparams = len(params) + out: List[Optional[Buffer]] = [None] * nparams + + # If we have dumpers, it means set_dumper_types had been called, in + # which case self.types and self.formats are set to sequences of the + # right size. + if self._row_dumpers: + for i in range(nparams): + param = params[i] + if param is not None: + out[i] = self._row_dumpers[i].dump(param) + return out + + types = [self._get_none_oid()] * nparams + pqformats = [TEXT] * nparams + + for i in range(nparams): + param = params[i] + if param is None: + continue + dumper = self.get_dumper(param, formats[i]) + out[i] = dumper.dump(param) + types[i] = dumper.oid + pqformats[i] = dumper.format + + self.types = tuple(types) + self.formats = pqformats + + return out + + def as_literal(self, obj: Any) -> bytes: + dumper = self.get_dumper(obj, PY_TEXT) + rv = dumper.quote(obj) + # If the result is quoted, and the oid not unknown or text, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + oid = dumper.oid + if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID: + try: + type_sql = self._oid_types[oid] + except KeyError: + ti = self.adapters.types.get(oid) + if ti: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + else: + type_sql = b"" + self._oid_types[oid] = type_sql + + if type_sql: + rv = b"%s::%s" % (rv, type_sql) + + if not isinstance(rv, bytes): + rv = bytes(rv) + return rv + + def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Return a Dumper instance to dump `!obj`. + """ + # Normally, the type of the object dictates how to dump it + key = type(obj) + + # Reuse an existing Dumper class for objects of the same type + cache = self._dumpers[format] + try: + dumper = cache[key] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper(key, format) + cache[key] = dumper = dcls(key, self) + + # Check if the dumper requires an upgrade to handle this specific value + key1 = dumper.get_key(obj, format) + if key1 is key: + return dumper + + # If it does, ask the dumper to create its own upgraded version + try: + return cache[key1] + except KeyError: + dumper = cache[key1] = dumper.upgrade(obj, format) + return dumper + + def _get_none_oid(self) -> int: + try: + return self._none_oid + except AttributeError: + pass + + try: + rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid + except KeyError: + raise e.InterfaceError("None dumper not found") + + return rv + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper": + """ + Return a Dumper to dump an object to the type with given oid. + """ + if not self._oid_dumpers: + self._oid_dumpers = ({}, {}) + + # Reuse an existing Dumper class for objects of the same type + cache = self._oid_dumpers[format] + try: + return cache[oid] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper_by_oid(oid, format) + cache[oid] = dumper = dcls(NoneType, self) + + return dumper + + def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: + res = self._pgresult + if not res: + raise e.InterfaceError("result not set") + + if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): + raise e.InterfaceError( + f"rows must be included between 0 and {self._ntuples}" + ) + + records = [] + for row in range(row0, row1): + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + records.append(make_row(record)) + + return records + + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: + res = self._pgresult + if not res: + return None + + if not 0 <= row < self._ntuples: + return None + + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + + return make_row(record) + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + if len(self._row_loaders) != len(record): + raise e.ProgrammingError( + f"cannot load sequence of {len(record)} items:" + f" {len(self._row_loaders)} loaders registered" + ) + + return tuple( + (self._row_loaders[i](val) if val is not None else None) + for i, val in enumerate(record) + ) + + def get_loader(self, oid: int, format: pq.Format) -> "Loader": + try: + return self._loaders[format][oid] + except KeyError: + pass + + loader_cls = self._adapters.get_loader(oid, format) + if not loader_cls: + loader_cls = self._adapters.get_loader(INVALID_OID, format) + if not loader_cls: + raise e.InterfaceError("unknown oid loader not found") + loader = self._loaders[format][oid] = loader_cls(oid, self) + return loader diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py new file mode 100644 index 0000000..2f1a24d --- /dev/null +++ b/psycopg/psycopg/_typeinfo.py @@ -0,0 +1,461 @@ +""" +Information about PostgreSQL types + +These types allow to read information from the system catalog and provide +information to the adapters if needed. +""" + +# Copyright (C) 2020 The Psycopg Team +from enum import Enum +from typing import Any, Dict, Iterator, Optional, overload +from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import errors as e +from .abc import AdaptContext +from .rows import dict_row + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + from .sql import Identifier + +T = TypeVar("T", bound="TypeInfo") +RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]] + + +class TypeInfo: + """ + Hold information about a PostgreSQL base type. + """ + + __module__ = "psycopg.types" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + delimiter: str = ",", + ): + self.name = name + self.oid = oid + self.array_oid = array_oid + self.regtype = regtype or name + self.delimiter = delimiter + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__qualname__}:" + f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>" + ) + + @overload + @classmethod + def fetch( + cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"] + ) -> Optional[T]: + ... + + @overload + @classmethod + async def fetch( + cls: Type[T], + conn: "AsyncConnection[Any]", + name: Union[str, "Identifier"], + ) -> Optional[T]: + ... + + @classmethod + def fetch( + cls: Type[T], + conn: "Union[Connection[Any], AsyncConnection[Any]]", + name: Union[str, "Identifier"], + ) -> Any: + """Query a system catalog to read information about a type.""" + from .sql import Composable + from .connection_async import AsyncConnection + + if isinstance(name, Composable): + name = name.as_string(conn) + + if isinstance(conn, AsyncConnection): + return cls._fetch_async(conn, name) + + # This might result in a nested transaction. What we want is to leave + # the function with the connection in the state we found (either idle + # or intrans) + try: + with conn.transaction(): + with conn.cursor(binary=True, row_factory=dict_row) as cur: + cur.execute(cls._get_info_query(conn), {"name": name}) + recs = cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + async def _fetch_async( + cls: Type[T], conn: "AsyncConnection[Any]", name: str + ) -> Optional[T]: + """ + Query a system catalog to read information about a type. + + Similar to `fetch()` but can use an asynchronous connection. + """ + try: + async with conn.transaction(): + async with conn.cursor(binary=True, row_factory=dict_row) as cur: + await cur.execute(cls._get_info_query(conn), {"name": name}) + recs = await cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + def _from_records( + cls: Type[T], name: str, recs: Sequence[Dict[str, Any]] + ) -> Optional[T]: + if len(recs) == 1: + return cls(**recs[0]) + elif not recs: + return None + else: + raise e.ProgrammingError(f"found {len(recs)} different types named {name}") + + def register(self, context: Optional[AdaptContext] = None) -> None: + """ + Register the type information, globally or in the specified `!context`. + """ + if context: + types = context.adapters.types + else: + from . import postgres + + types = postgres.types + + types.add(self) + + if self.array_oid: + from .types.array import register_array + + register_array(self, context) + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + typname AS name, oid, typarray AS array_oid, + oid::regtype::text AS regtype, typdelim AS delimiter +FROM pg_type t +WHERE t.oid = %(name)s::regtype +ORDER BY t.oid +""" + + def _added(self, registry: "TypesRegistry") -> None: + """Method called by the `!registry` when the object is added there.""" + pass + + +class RangeInfo(TypeInfo): + """Manage information about a range type.""" + + __module__ = "psycopg.types.range" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngtypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map ranges subtypes to info + registry._registry[RangeInfo, self.subtype_oid] = self + + +class MultirangeInfo(TypeInfo): + """Manage information about a multirange type.""" + + # TODO: expose to multirange module once added + # __module__ = "psycopg.types.multirange" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + range_oid: int, + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.range_oid = range_oid + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + if conn.info.server_version < 140000: + raise e.NotSupportedError( + "multirange types are only available from PostgreSQL 14" + ) + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngmultitypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map multiranges ranges and subtypes to info + registry._registry[MultirangeInfo, self.range_oid] = self + registry._registry[MultirangeInfo, self.subtype_oid] = self + + +class CompositeInfo(TypeInfo): + """Manage information about a composite type.""" + + __module__ = "psycopg.types.composite" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + field_names: Sequence[str], + field_types: Sequence[int], + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.field_names = field_names + self.field_types = field_types + # Will be set by register() if the `factory` is a type + self.python_type: Optional[type] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + coalesce(a.fnames, '{}') AS field_names, + coalesce(a.ftypes, '{}') AS field_types +FROM pg_type t +LEFT JOIN ( + SELECT + attrelid, + array_agg(attname) AS fnames, + array_agg(atttypid) AS ftypes + FROM ( + SELECT a.attrelid, a.attname, a.atttypid + FROM pg_attribute a + JOIN pg_type t ON t.typrelid = a.attrelid + WHERE t.oid = %(name)s::regtype + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + ) x + GROUP BY attrelid +) a ON a.attrelid = t.typrelid +WHERE t.oid = %(name)s::regtype +""" + + +class EnumInfo(TypeInfo): + """Manage information about an enum type.""" + + __module__ = "psycopg.types.enum" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + labels: Sequence[str], + ): + super().__init__(name, oid, array_oid) + self.labels = labels + # Will be set by register_enum() + self.enum: Optional[Type[Enum]] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT name, oid, array_oid, array_agg(label) AS labels +FROM ( + SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + e.enumlabel AS label + FROM pg_type t + LEFT JOIN pg_enum e + ON e.enumtypid = t.oid + WHERE t.oid = %(name)s::regtype + ORDER BY e.enumsortorder +) x +GROUP BY name, oid, array_oid +""" + + +class TypesRegistry: + """ + Container for the information about types in a database. + """ + + __module__ = "psycopg.types" + + def __init__(self, template: Optional["TypesRegistry"] = None): + self._registry: Dict[RegistryKey, TypeInfo] + + # Make a shallow copy: it will become a proper copy if the registry + # is edited. + if template: + self._registry = template._registry + self._own_state = False + template._own_state = False + else: + self.clear() + + def clear(self) -> None: + self._registry = {} + self._own_state = True + + def add(self, info: TypeInfo) -> None: + self._ensure_own_state() + if info.oid: + self._registry[info.oid] = info + if info.array_oid: + self._registry[info.array_oid] = info + self._registry[info.name] = info + + if info.regtype and info.regtype not in self._registry: + self._registry[info.regtype] = info + + # Allow info to customise further their relation with the registry + info._added(self) + + def __iter__(self) -> Iterator[TypeInfo]: + seen = set() + for t in self._registry.values(): + if id(t) not in seen: + seen.add(id(t)) + yield t + + @overload + def __getitem__(self, key: Union[str, int]) -> TypeInfo: + ... + + @overload + def __getitem__(self, key: Tuple[Type[T], int]) -> T: + ... + + def __getitem__(self, key: RegistryKey) -> TypeInfo: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Raise KeyError if not found. + """ + if isinstance(key, str): + if key.endswith("[]"): + key = key[:-2] + elif not isinstance(key, (int, tuple)): + raise TypeError(f"the key must be an oid or a name, got {type(key)}") + try: + return self._registry[key] + except KeyError: + raise KeyError(f"couldn't find the type {key!r} in the types registry") + + @overload + def get(self, key: Union[str, int]) -> Optional[TypeInfo]: + ... + + @overload + def get(self, key: Tuple[Type[T], int]) -> Optional[T]: + ... + + def get(self, key: RegistryKey) -> Optional[TypeInfo]: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Unlike `__getitem__`, return None if not found. + """ + try: + return self[key] + except KeyError: + return None + + def get_oid(self, name: str) -> int: + """ + Return the oid of a PostgreSQL type by name. + + :param key: the name of the type to look for. + + Return the array oid if the type ends with "``[]``" + + Raise KeyError if the name is unknown. + """ + t = self[name] + if name.endswith("[]"): + return t.array_oid + else: + return t.oid + + def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]: + """ + Return info about a `TypeInfo` subclass by its element name or oid. + + :param cls: the subtype of `!TypeInfo` to look for. Currently + supported are `~psycopg.types.range.RangeInfo` and + `~psycopg.types.multirange.MultirangeInfo`. + :param subtype: The name or OID of the subtype of the element to look for. + :return: The `!TypeInfo` object of class `!cls` whose subtype is + `!subtype`. `!None` if the element or its range are not found. + """ + try: + info = self[subtype] + except KeyError: + return None + return self.get((cls, info.oid)) + + def _ensure_own_state(self) -> None: + # Time to write! so, copy. + if not self._own_state: + self._registry = self._registry.copy() + self._own_state = True diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py new file mode 100644 index 0000000..813ed62 --- /dev/null +++ b/psycopg/psycopg/_tz.py @@ -0,0 +1,44 @@ +""" +Timezone utility functions. +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import Dict, Optional, Union +from datetime import timezone, tzinfo + +from .pq.abc import PGconn +from ._compat import ZoneInfo + +logger = logging.getLogger("psycopg") + +_timezones: Dict[Union[None, bytes], tzinfo] = { + None: timezone.utc, + b"UTC": timezone.utc, +} + + +def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo: + """Return the Python timezone info of the connection's timezone.""" + tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None + try: + return _timezones[tzname] + except KeyError: + sname = tzname.decode() if tzname else "UTC" + try: + zi: tzinfo = ZoneInfo(sname) + except (KeyError, OSError): + logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname) + zi = timezone.utc + except Exception as ex: + logger.warning( + "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)", + sname, + type(ex).__name__, + ex, + ) + zi = timezone.utc + + _timezones[tzname] = zi + return zi diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py new file mode 100644 index 0000000..f861741 --- /dev/null +++ b/psycopg/psycopg/_wrappers.py @@ -0,0 +1,137 @@ +""" +Wrappers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Wrappers to force numbers to be cast as specific PostgreSQL types + +# These types are implemented here but exposed by `psycopg.types.numeric`. +# They are defined here to avoid a circular import. +_MODULE = "psycopg.types.numeric" + + +class Int2(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int2": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int4(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int8(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class IntNumeric(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "IntNumeric": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float4(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float8(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Oid(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`oid`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Oid": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py new file mode 100644 index 0000000..80c8fbf --- /dev/null +++ b/psycopg/psycopg/abc.py @@ -0,0 +1,266 @@ +""" +Protocol objects representing different implementations of the same classes. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Generator, Mapping +from typing import List, Optional, Sequence, Tuple, TypeVar, Union +from typing import TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from ._enums import PyFormat as PyFormat +from ._compat import Protocol, LiteralString + +if TYPE_CHECKING: + from . import sql + from .rows import Row, RowMaker + from .pq.abc import PGresult + from .waiting import Wait, Ready + from .connection import BaseConnection + from ._adapters_map import AdaptersMap + +NoneType: type = type(None) + +# An object implementing the buffer protocol +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + +Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"] +Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] +ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") +PipelineCommand: TypeAlias = Callable[[], None] +DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]] + +# Waiting protocol types + +RV = TypeVar("RV") + +PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV] +"""Generator for processes where the connection file number can change. + +This can happen in connection and reset, but not in normal querying. +""" + +PQGen: TypeAlias = Generator["Wait", "Ready", RV] +"""Generator for processes where the connection file number won't change. +""" + + +class WaitFunc(Protocol): + """ + Wait on the connection which generated `PQgen` and return its final result. + """ + + def __call__( + self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None + ) -> RV: + ... + + +# Adaptation types + +DumpFunc: TypeAlias = Callable[[Any], Buffer] +LoadFunc: TypeAlias = Callable[[Buffer], Any] + + +class AdaptContext(Protocol): + """ + A context describing how types are adapted. + + Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`, + `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`. + + Note that this is a `~typing.Protocol`, so objects implementing + `!AdaptContext` don't need to explicitly inherit from this class. + + """ + + @property + def adapters(self) -> "AdaptersMap": + """The adapters configuration that this object uses.""" + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + """The connection used by this object, if available. + + :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None` + """ + ... + + +class Dumper(Protocol): + """ + Convert Python objects of type `!cls` to PostgreSQL representation. + """ + + format: pq.Format + """ + The format that this class `dump()` method produces, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + oid: int + """The oid to pass to the server, if known; 0 otherwise (class attribute).""" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + ... + + def dump(self, obj: Any) -> Buffer: + """Convert the object `!obj` to PostgreSQL representation. + + :param obj: the object to convert. + """ + ... + + def quote(self, obj: Any) -> Buffer: + """Convert the object `!obj` to escaped representation. + + :param obj: the object to convert. + """ + ... + + def get_key(self, obj: Any, format: PyFormat) -> DumperKey: + """Return an alternative key to upgrade the dumper to represent `!obj`. + + :param obj: The object to convert + :param format: The format to convert to + + Normally the type of the object is all it takes to define how to dump + the object to the database. For instance, a Python `~datetime.date` can + be simply converted into a PostgreSQL :sql:`date`. + + In a few cases, just the type is not enough. For example: + + - A Python `~datetime.datetime` could be represented as a + :sql:`timestamptz` or a :sql:`timestamp`, according to whether it + specifies a `!tzinfo` or not. + + - A Python int could be stored as several Postgres types: int2, int4, + int8, numeric. If a type too small is used, it may result in an + overflow. If a type too large is used, PostgreSQL may not want to + cast it to a smaller type. + + - Python lists should be dumped according to the type they contain to + convert them to e.g. array of strings, array of ints (and which + size of int?...) + + In these cases, a dumper can implement `!get_key()` and return a new + class, or sequence of classes, that can be used to identify the same + dumper again. If the mechanism is not needed, the method should return + the same `!cls` object passed in the constructor. + + If a dumper implements `get_key()` it should also implement + `upgrade()`. + + """ + ... + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """Return a new dumper to manage `!obj`. + + :param obj: The object to convert + :param format: The format to convert to + + Once `Transformer.get_dumper()` has been notified by `get_key()` that + this Dumper class cannot handle `!obj` itself, it will invoke + `!upgrade()`, which should return a new `Dumper` instance, which will + be reused for every objects for which `!get_key()` returns the same + result. + """ + ... + + +class Loader(Protocol): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format + """ + The format that this class `load()` method can convert, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + ... + + def load(self, data: Buffer) -> Any: + """ + Convert the data returned by the database into a Python object. + + :param data: the data to convert. + """ + ... + + +class Transformer(Protocol): + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + def __init__(self, context: Optional[AdaptContext] = None): + ... + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + ... + + @property + def encoding(self) -> str: + ... + + @property + def adapters(self) -> "AdaptersMap": + ... + + @property + def pgresult(self) -> Optional["PGresult"]: + ... + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None + ) -> None: + ... + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + ... + + def as_literal(self, obj: Any) -> bytes: + ... + + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: + ... + + def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]: + ... + + def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: + ... + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + ... + + def get_loader(self, oid: int, format: pq.Format) -> Loader: + ... diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py new file mode 100644 index 0000000..7ec4a55 --- /dev/null +++ b/psycopg/psycopg/adapt.py @@ -0,0 +1,162 @@ +""" +Entry point into the adaptation system. +""" + +# Copyright (C) 2020 The Psycopg Team + +from abc import ABC, abstractmethod +from typing import Any, Optional, Type, TYPE_CHECKING + +from . import pq, abc +from . import _adapters_map +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg + +if TYPE_CHECKING: + from .connection import BaseConnection + +AdaptersMap = _adapters_map.AdaptersMap +Buffer = abc.Buffer + +ORD_BS = ord("\\") + + +class Dumper(abc.Dumper, ABC): + """ + Convert Python object of the type `!cls` to PostgreSQL representation. + """ + + oid: int = 0 + """The oid to pass to the server, if known.""" + + format: pq.Format = pq.Format.TEXT + """The format of the data dumped.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + self.cls = cls + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + def __repr__(self) -> str: + return ( + f"<{type(self).__module__}.{type(self).__qualname__}" + f" (oid={self.oid}) at 0x{id(self):x}>" + ) + + @abstractmethod + def dump(self, obj: Any) -> Buffer: + ... + + def quote(self, obj: Any) -> Buffer: + """ + By default return the `dump()` value quoted and sanitised, so + that the result can be used to build a SQL string. This works well + for most types and you won't likely have to implement this method in a + subclass. + """ + value = self.dump(obj) + + if self.connection: + esc = pq.Escaping(self.connection.pgconn) + # escaping and quoting + return esc.escape_literal(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + # No quoting, only quote escaping, random bs escaping. See further. + esc = pq.Escaping() + out = esc.escape_string(value) + + # b"\\" in memoryview doesn't work so search for the ascii value + if ORD_BS not in out: + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + return b"'" + out + b"'" + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + rv: bytes = b" E'" + out + b"'" + if esc.escape_string(b"\\") == b"\\": + rv = rv.replace(b"\\", b"\\\\") + return rv + + def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey: + """ + Implementation of the `~psycopg.abc.Dumper.get_key()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation returns the `!cls` passed in the constructor. + Subclasses needing to specialise the PostgreSQL type according to the + *value* of the object dumped (not only according to to its type) + should override this class. + + """ + return self.cls + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation just returns `!self`. If a subclass implements + `get_key()` it should probably override `!upgrade()` too. + """ + return self + + +class Loader(abc.Loader, ABC): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format = pq.Format.TEXT + """The format of the data loaded.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + self.oid = oid + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + @abstractmethod + def load(self, data: Buffer) -> Any: + """Convert a PostgreSQL value to a Python object.""" + ... + + +Transformer: Type["abc.Transformer"] + +# Override it with fast object if available +if _psycopg: + Transformer = _psycopg.Transformer +else: + from . import _transform + + Transformer = _transform.Transformer + + +class RecursiveDumper(Dumper): + """Dumper with a transformer to help dumping recursive types.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self._tx = Transformer.from_context(context) + + +class RecursiveLoader(Loader): + """Loader with a transformer to help loading recursive types.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer.from_context(context) diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py new file mode 100644 index 0000000..6271ec5 --- /dev/null +++ b/psycopg/psycopg/client_cursor.py @@ -0,0 +1,95 @@ +""" +psycopg client-side binding cursors +""" + +# Copyright (C) 2022 The Psycopg Team + +from typing import Optional, Tuple, TYPE_CHECKING +from functools import partial + +from ._queries import PostgresQuery, PostgresClientQuery + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params +from .rows import Row +from .cursor import BaseCursor, Cursor +from ._preparing import Prepare +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from typing import Any # noqa: F401 + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + + +class ClientCursorMixin(BaseCursor[ConnectionType, Row]): + def mogrify(self, query: Query, params: Optional[Params] = None) -> str: + """ + Return the query and parameters merged. + + Parameters are adapted and merged to the query the same way that + `!execute()` would do. + + """ + self._tx = adapt.Transformer(self) + pgq = self._convert_query(query, params) + return pgq.query.decode(self._tx.encoding) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if fmt == BINARY: + raise e.NotSupportedError( + "client-side cursors don't support binary results" + ) + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial(self._pgconn.send_query_params, query.query, None) + ) + elif force_extended: + self._pgconn.send_query_params(query.query, None) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. + self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresClientQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return (Prepare.NO, b"") + + +class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + + +class AsyncClientCursor( + ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py new file mode 100644 index 0000000..78ad577 --- /dev/null +++ b/psycopg/psycopg/connection.py @@ -0,0 +1,1031 @@ +""" +psycopg connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +import threading +from types import TracebackType +from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator +from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union +from typing import overload, TYPE_CHECKING +from weakref import ref, ReferenceType +from warnings import warn +from functools import partial +from contextlib import contextmanager +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from . import waiting +from . import postgres +from .abc import AdaptContext, ConnectionType, Params, Query, RV +from .abc import PQGen, PQGenConn +from .sql import Composable, SQL +from ._tpc import Xid +from .rows import Row, RowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .cursor import Cursor +from ._compat import LiteralString +from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from ._pipeline import BasePipeline, Pipeline +from .generators import notifies, connect, execute +from ._encodings import pgconn_encoding +from ._preparing import PrepareManager +from .transaction import Transaction +from .server_cursor import ServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn, PGresult + from psycopg_pool.base import BasePool + + +# Row Type variable for Cursor (when it needs to be distinguished from the +# connection's one) +CursorRow = TypeVar("CursorRow") + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class Notify(NamedTuple): + """An asynchronous notification received from the database.""" + + channel: str + """The name of the channel on which the notification was received.""" + + payload: str + """The message attached to the notification.""" + + pid: int + """The PID of the backend process which sent the notification.""" + + +Notify.__module__ = "psycopg" + +NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] +NotifyHandler: TypeAlias = Callable[[Notify], None] + + +class BaseConnection(Generic[Row]): + """ + Base class for different types of connections. + + Share common functionalities such as access to the wrapped PGconn, but + allow different interfaces (sync/async). + """ + + # DBAPI2 exposed exceptions + Warning = e.Warning + Error = e.Error + InterfaceError = e.InterfaceError + DatabaseError = e.DatabaseError + DataError = e.DataError + OperationalError = e.OperationalError + IntegrityError = e.IntegrityError + InternalError = e.InternalError + ProgrammingError = e.ProgrammingError + NotSupportedError = e.NotSupportedError + + # Enums useful for the connection + ConnStatus = pq.ConnStatus + TransactionStatus = pq.TransactionStatus + + def __init__(self, pgconn: "PGconn"): + self.pgconn = pgconn + self._autocommit = False + + # None, but set to a copy of the global adapters map as soon as requested. + self._adapters: Optional[AdaptersMap] = None + + self._notice_handlers: List[NoticeHandler] = [] + self._notify_handlers: List[NotifyHandler] = [] + + # Number of transaction blocks currently entered + self._num_transactions = 0 + + self._closed = False # closed by an explicit close() + self._prepared: PrepareManager = PrepareManager() + self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared + + wself = ref(self) + pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) + pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) + + # Attribute is only set if the connection is from a pool so we can tell + # apart a connection in the pool too (when _pool = None) + self._pool: Optional["BasePool[Any]"] + + self._pipeline: Optional[BasePipeline] = None + + # Time after which the connection should be closed + self._expire_at: float + + self._isolation_level: Optional[IsolationLevel] = None + self._read_only: Optional[bool] = None + self._deferrable: Optional[bool] = None + self._begin_statement = b"" + + def __del__(self) -> None: + # If fails on connection we might not have this attribute yet + if not hasattr(self, "pgconn"): + return + + # Connection correctly closed + if self.closed: + return + + # Connection in a pool so terminating with the program is normal + if hasattr(self, "_pool"): + return + + warn( + f"connection {self} was deleted while still open." + " Please use 'with' or '.close()' to close the connection", + ResourceWarning, + ) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def closed(self) -> bool: + """`!True` if the connection is closed.""" + return self.pgconn.status == BAD + + @property + def broken(self) -> bool: + """ + `!True` if the connection was interrupted. + + A broken connection is always `closed`, but wasn't closed in a clean + way, such as using `close()` or a `!with` block. + """ + return self.pgconn.status == BAD and not self._closed + + @property + def autocommit(self) -> bool: + """The autocommit state of the connection.""" + return self._autocommit + + @autocommit.setter + def autocommit(self, value: bool) -> None: + self._set_autocommit(value) + + def _set_autocommit(self, value: bool) -> None: + raise NotImplementedError + + def _set_autocommit_gen(self, value: bool) -> PQGen[None]: + yield from self._check_intrans_gen("autocommit") + self._autocommit = bool(value) + + @property + def isolation_level(self) -> Optional[IsolationLevel]: + """ + The isolation level of the new transactions started on the connection. + """ + return self._isolation_level + + @isolation_level.setter + def isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._set_isolation_level(value) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + raise NotImplementedError + + def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]: + yield from self._check_intrans_gen("isolation_level") + self._isolation_level = IsolationLevel(value) if value is not None else None + self._begin_statement = b"" + + @property + def read_only(self) -> Optional[bool]: + """ + The read-only state of the new transactions started on the connection. + """ + return self._read_only + + @read_only.setter + def read_only(self, value: Optional[bool]) -> None: + self._set_read_only(value) + + def _set_read_only(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("read_only") + self._read_only = bool(value) + self._begin_statement = b"" + + @property + def deferrable(self) -> Optional[bool]: + """ + The deferrable state of the new transactions started on the connection. + """ + return self._deferrable + + @deferrable.setter + def deferrable(self, value: Optional[bool]) -> None: + self._set_deferrable(value) + + def _set_deferrable(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("deferrable") + self._deferrable = bool(value) + self._begin_statement = b"" + + def _check_intrans_gen(self, attribute: str) -> PQGen[None]: + # Raise an exception if we are in a transaction + status = self.pgconn.transaction_status + if status == IDLE and self._pipeline: + yield from self._pipeline._sync_gen() + status = self.pgconn.transaction_status + if status != IDLE: + if self._num_transactions: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection.transaction() context in progress" + ) + else: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection in transaction status " + f"{pq.TransactionStatus(status).name}" + ) + + @property + def info(self) -> ConnectionInfo: + """A `ConnectionInfo` attribute to inspect connection properties.""" + return ConnectionInfo(self.pgconn) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + self._adapters = AdaptersMap(postgres.adapters) + + return self._adapters + + @property + def connection(self) -> "BaseConnection[Row]": + # implement the AdaptContext protocol + return self + + def fileno(self) -> int: + """Return the file descriptor of the connection. + + This function allows to use the connection as file-like object in + functions waiting for readiness, such as the ones defined in the + `selectors` module. + """ + return self.pgconn.socket + + def cancel(self) -> None: + """Cancel the current operation on the connection.""" + # No-op if the connection is closed + # this allows to use the method as callback handler without caring + # about its life. + if self.closed: + return + + if self._tpc and self._tpc[1]: + raise e.ProgrammingError( + "cancel() cannot be used with a prepared two-phase transaction" + ) + + c = self.pgconn.get_cancel() + c.cancel() + + def add_notice_handler(self, callback: NoticeHandler) -> None: + """ + Register a callable to be invoked when a notice message is received. + + :param callback: the callback to call upon message received. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.append(callback) + + def remove_notice_handler(self, callback: NoticeHandler) -> None: + """ + Unregister a notice message callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.remove(callback) + + @staticmethod + def _notice_handler( + wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" + ) -> None: + self = wself() + if not (self and self._notice_handlers): + return + + diag = e.Diagnostic(res, pgconn_encoding(self.pgconn)) + for cb in self._notice_handlers: + try: + cb(diag) + except Exception as ex: + logger.exception("error processing notice callback '%s': %s", cb, ex) + + def add_notify_handler(self, callback: NotifyHandler) -> None: + """ + Register a callable to be invoked whenever a notification is received. + + :param callback: the callback to call upon notification received. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.append(callback) + + def remove_notify_handler(self, callback: NotifyHandler) -> None: + """ + Unregister a notification callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.remove(callback) + + @staticmethod + def _notify_handler( + wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify + ) -> None: + self = wself() + if not (self and self._notify_handlers): + return + + enc = pgconn_encoding(self.pgconn) + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + for cb in self._notify_handlers: + cb(n) + + @property + def prepare_threshold(self) -> Optional[int]: + """ + Number of times a query is executed before it is prepared. + + - If it is set to 0, every query is prepared the first time it is + executed. + - If it is set to `!None`, prepared statements are disabled on the + connection. + + Default value: 5 + """ + return self._prepared.prepare_threshold + + @prepare_threshold.setter + def prepare_threshold(self, value: Optional[int]) -> None: + self._prepared.prepare_threshold = value + + @property + def prepared_max(self) -> int: + """ + Maximum number of prepared statements on the connection. + + Default value: 100 + """ + return self._prepared.prepared_max + + @prepared_max.setter + def prepared_max(self, value: int) -> None: + self._prepared.prepared_max = value + + # Generators to perform high-level operations on the connection + # + # These operations are expressed in terms of non-blocking generators + # and the task of waiting when needed (when the generators yield) is left + # to the connections subclass, which might wait either in blocking mode + # or through asyncio. + # + # All these generators assume exclusive access to the connection: subclasses + # should have a lock and hold it before calling and consuming them. + + @classmethod + def _connect_gen( + cls: Type[ConnectionType], + conninfo: str = "", + *, + autocommit: bool = False, + ) -> PQGenConn[ConnectionType]: + """Generator to connect to the database and create a new instance.""" + pgconn = yield from connect(conninfo) + conn = cls(pgconn) + conn._autocommit = bool(autocommit) + return conn + + def _exec_command( + self, command: Query, result_format: pq.Format = TEXT + ) -> PQGen[Optional["PGresult"]]: + """ + Generator to send a command and receive the result to the backend. + + Only used to implement internal commands such as "commit", with eventual + arguments bound client-side. The cursor can do more complex stuff. + """ + self._check_connection_ok() + + if isinstance(command, str): + command = command.encode(pgconn_encoding(self.pgconn)) + elif isinstance(command, Composable): + command = command.as_bytes(self) + + if self._pipeline: + cmd = partial( + self.pgconn.send_query_params, + command, + None, + result_format=result_format, + ) + self._pipeline.command_queue.append(cmd) + self._pipeline.result_queue.append(None) + return None + + self.pgconn.send_query_params(command, None, result_format=result_format) + + result = (yield from execute(self.pgconn))[-1] + if result.status != COMMAND_OK and result.status != TUPLES_OK: + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + else: + raise e.InterfaceError( + f"unexpected result {pq.ExecStatus(result.status).name}" + f" from command {command.decode()!r}" + ) + return result + + def _check_connection_ok(self) -> None: + if self.pgconn.status == OK: + return + + if self.pgconn.status == BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + "cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + + def _start_query(self) -> PQGen[None]: + """Generator to start a transaction if necessary.""" + if self._autocommit: + return + + if self.pgconn.transaction_status != IDLE: + return + + yield from self._exec_command(self._get_tx_start_command()) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _get_tx_start_command(self) -> bytes: + if self._begin_statement: + return self._begin_statement + + parts = [b"BEGIN"] + + if self.isolation_level is not None: + val = IsolationLevel(self.isolation_level) + parts.append(b"ISOLATION LEVEL") + parts.append(val.name.replace("_", " ").encode()) + + if self.read_only is not None: + parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") + + if self.deferrable is not None: + parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") + + self._begin_statement = b" ".join(parts) + return self._begin_statement + + def _commit_gen(self) -> PQGen[None]: + """Generator implementing `Connection.commit()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "commit() cannot be used during a two-phase transaction" + ) + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"COMMIT") + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _rollback_gen(self) -> PQGen[None]: + """Generator implementing `Connection.rollback()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. (Either raise Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "rollback() cannot be used during a two-phase transaction" + ) + + # Get out of a "pipeline aborted" state + if self._pipeline: + yield from self._pipeline._sync_gen() + + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"ROLLBACK") + self._prepared.clear() + for cmd in self._prepared.get_maintenance_commands(): + yield from self._exec_command(cmd) + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: + """ + Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. + + The argument types and constraints are explained in + :ref:`two-phase-commit`. + + The values passed to the method will be available on the returned + object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. + """ + self._check_tpc() + return Xid.from_parts(format_id, gtrid, bqual) + + def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: + self._check_tpc() + + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self.pgconn.transaction_status != IDLE: + raise e.ProgrammingError( + "can't start two-phase transaction: connection in status" + f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" + ) + + if self._autocommit: + raise e.ProgrammingError( + "can't use two-phase transactions in autocommit mode" + ) + + self._tpc = (xid, False) + yield from self._exec_command(self._get_tx_start_command()) + + def _tpc_prepare_gen(self) -> PQGen[None]: + if not self._tpc: + raise e.ProgrammingError( + "'tpc_prepare()' must be called inside a two-phase transaction" + ) + if self._tpc[1]: + raise e.ProgrammingError( + "'tpc_prepare()' cannot be used during a prepared two-phase transaction" + ) + xid = self._tpc[0] + self._tpc = (xid, True) + yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _tpc_finish_gen( + self, action: LiteralString, xid: Union[Xid, str, None] + ) -> PQGen[None]: + fname = f"tpc_{action.lower()}()" + if xid is None: + if not self._tpc: + raise e.ProgrammingError( + f"{fname} without xid must must be" + " called inside a two-phase transaction" + ) + xid = self._tpc[0] + else: + if self._tpc: + raise e.ProgrammingError( + f"{fname} with xid must must be called" + " outside a two-phase transaction" + ) + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self._tpc and not self._tpc[1]: + meth: Callable[[], PQGen[None]] + meth = getattr(self, f"_{action.lower()}_gen") + self._tpc = None + yield from meth() + else: + yield from self._exec_command( + SQL("{} PREPARED {}").format(SQL(action), str(xid)) + ) + self._tpc = None + + def _check_tpc(self) -> None: + """Raise NotSupportedError if TPC is not supported.""" + # TPC supported on every supported PostgreSQL version. + pass + + +class Connection(BaseConnection[Row]): + """ + Wrapper for a connection to the database. + """ + + __module__ = "psycopg" + + cursor_factory: Type[Cursor[Row]] + server_cursor_factory: Type[ServerCursor[Row]] + row_factory: RowFactory[Row] + _pipeline: Optional[Pipeline] + _Self = TypeVar("_Self", bound="Connection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = threading.Lock() + self.cursor_factory = Cursor + self.server_cursor_factory = ServerCursor + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[TupleRow]": + ... + + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: Optional[RowFactory[Row]] = None, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Any, + ) -> "Connection[Any]": + """ + Connect to a database server and return a new `Connection` instance. + """ + params = cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + self.close() + + @classmethod + def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + :param conninfo: Connection string as received by `~Connection.connect()`. + :param kwargs: Overriding connection arguments as received by `!connect()`. + :return: Connection arguments merged and eventually modified, in a + format similar to `~conninfo.conninfo_to_dict()`. + """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + return params + + def close(self) -> None: + """Close the database connection.""" + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> Cursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + ) -> Cursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: RowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[CursorRow]: + ... + + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[RowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[Cursor[Any], ServerCursor[Any]]: + """ + Return a new cursor to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[Cursor[Any], ServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> Cursor[Row]: + """Execute a query and return a cursor to read its results.""" + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + with self.lock: + self.wait(self._commit_gen()) + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + with self.lock: + self.wait(self._rollback_gen()) + + @contextmanager + def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> Iterator[Transaction]: + """ + Start a context block with a new transaction or nested transaction. + + :param savepoint_name: Name of the savepoint used to manage a nested + transaction. If `!None`, one will be chosen automatically. + :param force_rollback: Roll back the transaction at the end of the + block even if there were no error (e.g. to try a no-op process). + :rtype: Transaction + """ + tx = Transaction(self, savepoint_name, force_rollback) + if self._pipeline: + with self.pipeline(), tx, self.pipeline(): + yield tx + else: + with tx: + yield tx + + def notifies(self) -> Generator[Notify, None, None]: + """ + Yield `Notify` objects as soon as they are received from the database. + """ + while True: + with self.lock: + try: + ns = self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @contextmanager + def pipeline(self) -> Iterator[Pipeline]: + """Switch the connection into pipeline mode.""" + with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = Pipeline(self) + + try: + with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + """ + Consume a generator operating on the connection. + + The function must be used on generators that don't change connection + fd (i.e. not on connect and reset). + """ + try: + return waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except KeyboardInterrupt: + # On Ctrl-C, try to cancel the query in the server, otherwise + # the connection will remain stuck in ACTIVE state. + c = self.pgconn.get_cancel() + c.cancel() + try: + waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + """Consume a connection generator.""" + return waiting.wait_conn(gen, timeout=timeout) + + def _set_autocommit(self, value: bool) -> None: + with self.lock: + self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + with self.lock: + self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_deferrable_gen(value)) + + def tpc_begin(self, xid: Union[Xid, str]) -> None: + """ + Begin a TPC transaction with the given transaction ID `!xid`. + """ + with self.lock: + self.wait(self._tpc_begin_gen(xid)) + + def tpc_prepare(self) -> None: + """ + Perform the first phase of a transaction started with `tpc_begin()`. + """ + try: + with self.lock: + self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + """ + Commit a prepared two-phase transaction. + """ + with self.lock: + self.wait(self._tpc_finish_gen("COMMIT", xid)) + + def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + """ + Roll back a prepared two-phase transaction. + """ + with self.lock: + self.wait(self._tpc_finish_gen("ROLLBACK", xid)) + + def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + cur.execute(Xid._get_recover_query()) + res = cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + self.rollback() + + return res diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py new file mode 100644 index 0000000..aa02dc0 --- /dev/null +++ b/psycopg/psycopg/connection_async.py @@ -0,0 +1,436 @@ +""" +psycopg async connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import asyncio +import logging +from types import TracebackType +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING +from contextlib import asynccontextmanager + +from . import pq +from . import errors as e +from . import waiting +from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from ._tpc import Xid +from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async +from ._pipeline import AsyncPipeline +from ._encodings import pgconn_encoding +from .connection import BaseConnection, CursorRow, Notify +from .generators import notifies +from .transaction import AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class AsyncConnection(BaseConnection[Row]): + """ + Asynchronous wrapper for a connection to the database. + """ + + __module__ = "psycopg" + + cursor_factory: Type[AsyncCursor[Row]] + server_cursor_factory: Type[AsyncServerCursor[Row]] + row_factory: AsyncRowFactory[Row] + _pipeline: Optional[AsyncPipeline] + _Self = TypeVar("_Self", bound="AsyncConnection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = asyncio.Lock() + self.cursor_factory = AsyncCursor + self.server_cursor_factory = AsyncServerCursor + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[AsyncCursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[TupleRow]": + ... + + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + context: Optional[AdaptContext] = None, + row_factory: Optional[AsyncRowFactory[Row]] = None, + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + **kwargs: Any, + ) -> "AsyncConnection[Any]": + + if sys.platform == "win32": + loop = asyncio.get_running_loop() + if isinstance(loop, asyncio.ProactorEventLoop): + raise e.InterfaceError( + "Psycopg cannot use the 'ProactorEventLoop' to run in async" + " mode. Please use a compatible event loop, for instance by" + " setting 'asyncio.set_event_loop_policy" + "(WindowsSelectorEventLoopPolicy())'" + ) + + params = await cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = await cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + await self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + await self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + await self.close() + + @classmethod + async def _get_connection_params( + cls, conninfo: str, **kwargs: Any + ) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + .. versionchanged:: 3.1 + Unlike the sync counterpart, perform non-blocking address + resolution and populate the ``hostaddr`` connection parameter, + unless the user has provided one themselves. See + `~psycopg._dns.resolve_hostaddr_async()` for details. + + """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + # Resolve host addresses in non-blocking way + params = await resolve_hostaddr_async(params) + + return params + + async def close(self) -> None: + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow] + ) -> AsyncCursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: AsyncRowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[CursorRow]: + ... + + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[AsyncRowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]: + """ + Return a new `AsyncCursor` to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + async def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> AsyncCursor[Row]: + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return await cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + async def commit(self) -> None: + async with self.lock: + await self.wait(self._commit_gen()) + + async def rollback(self) -> None: + async with self.lock: + await self.wait(self._rollback_gen()) + + @asynccontextmanager + async def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> AsyncIterator[AsyncTransaction]: + """ + Start a context block with a new transaction or nested transaction. + + :rtype: AsyncTransaction + """ + tx = AsyncTransaction(self, savepoint_name, force_rollback) + if self._pipeline: + async with self.pipeline(), tx, self.pipeline(): + yield tx + else: + async with tx: + yield tx + + async def notifies(self) -> AsyncGenerator[Notify, None]: + while True: + async with self.lock: + try: + ns = await self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @asynccontextmanager + async def pipeline(self) -> AsyncIterator[AsyncPipeline]: + """Context manager to switch the connection into pipeline mode.""" + async with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = AsyncPipeline(self) + + try: + async with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + async with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + async def wait(self, gen: PQGen[RV]) -> RV: + try: + return await waiting.wait_async(gen, self.pgconn.socket) + except KeyboardInterrupt: + # TODO: this doesn't seem to work as it does for sync connections + # see tests/test_concurrency_async.py::test_ctrl_c + # In the test, the code doesn't reach this branch. + + # On Ctrl-C, try to cancel the query in the server, otherwise + # otherwise the connection will be stuck in ACTIVE state + c = self.pgconn.get_cancel() + c.cancel() + try: + await waiting.wait_async(gen, self.pgconn.socket) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + return await waiting.wait_conn_async(gen, timeout) + + def _set_autocommit(self, value: bool) -> None: + self._no_set_async("autocommit") + + async def set_autocommit(self, value: bool) -> None: + """Async version of the `~Connection.autocommit` setter.""" + async with self.lock: + await self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._no_set_async("isolation_level") + + async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + """Async version of the `~Connection.isolation_level` setter.""" + async with self.lock: + await self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + self._no_set_async("read_only") + + async def set_read_only(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.read_only` setter.""" + async with self.lock: + await self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + self._no_set_async("deferrable") + + async def set_deferrable(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.deferrable` setter.""" + async with self.lock: + await self.wait(self._set_deferrable_gen(value)) + + def _no_set_async(self, attribute: str) -> None: + raise AttributeError( + f"'the {attribute!r} property is read-only on async connections:" + f" please use 'await .set_{attribute}()' instead." + ) + + async def tpc_begin(self, xid: Union[Xid, str]) -> None: + async with self.lock: + await self.wait(self._tpc_begin_gen(xid)) + + async def tpc_prepare(self) -> None: + try: + async with self.lock: + await self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("commit", xid)) + + async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("rollback", xid)) + + async def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + async with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + await cur.execute(Xid._get_recover_query()) + res = await cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + await self.rollback() + + return res diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py new file mode 100644 index 0000000..3b21f83 --- /dev/null +++ b/psycopg/psycopg/conninfo.py @@ -0,0 +1,378 @@ +""" +Functions to manipulate conninfo strings +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import socket +import asyncio +from typing import Any, Dict, List, Optional +from pathlib import Path +from datetime import tzinfo +from functools import lru_cache +from ipaddress import ip_address + +from . import pq +from . import errors as e +from ._tz import get_tzinfo +from ._encodings import pgconn_encoding + + +def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: + """ + Merge a string and keyword params into a single conninfo string. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: A connection string valid for PostgreSQL, with the `!kwargs` + parameters merged. + + Raise `~psycopg.ProgrammingError` if the input doesn't make a valid + conninfo string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + if not conninfo and not kwargs: + return "" + + # If no kwarg specified don't mung the conninfo but check if it's correct. + # Make sure to return a string, not a subtype, to avoid making Liskov sad. + if not kwargs: + _parse_conninfo(conninfo) + return str(conninfo) + + # Override the conninfo with the parameters + # Drop the None arguments + kwargs = {k: v for (k, v) in kwargs.items() if v is not None} + + if conninfo: + tmp = conninfo_to_dict(conninfo) + tmp.update(kwargs) + kwargs = tmp + + conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()) + + # Verify the result is valid + _parse_conninfo(conninfo) + + return conninfo + + +def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: + """ + Convert the `!conninfo` string into a dictionary of parameters. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: Dictionary with the parameters parsed from `!conninfo` and + `!kwargs`. + + Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection + string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + opts = _parse_conninfo(conninfo) + rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} + for k, v in kwargs.items(): + if v is not None: + rv[k] = v + return rv + + +def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: + """ + Verify that `!conninfo` is a valid connection string. + + Raise ProgrammingError if the string is not valid. + + Return the result of pq.Conninfo.parse() on success. + """ + try: + return pq.Conninfo.parse(conninfo.encode()) + except e.OperationalError as ex: + raise e.ProgrammingError(str(ex)) + + +re_escape = re.compile(r"([\\'])") +re_space = re.compile(r"\s") + + +def _param_escape(s: str) -> str: + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: + return "''" + + s = re_escape.sub(r"\\\1", s) + if re_space.search(s): + s = "'" + s + "'" + + return s + + +class ConnectionInfo: + """Allow access to information about the connection.""" + + __module__ = "psycopg" + + def __init__(self, pgconn: pq.abc.PGconn): + self.pgconn = pgconn + + @property + def vendor(self) -> str: + """A string representing the database vendor connected to.""" + return "PostgreSQL" + + @property + def host(self) -> str: + """The server host name of the active connection. See :pq:`PQhost()`.""" + return self._get_pgconn_attr("host") + + @property + def hostaddr(self) -> str: + """The server IP address of the connection. See :pq:`PQhostaddr()`.""" + return self._get_pgconn_attr("hostaddr") + + @property + def port(self) -> int: + """The port of the active connection. See :pq:`PQport()`.""" + return int(self._get_pgconn_attr("port")) + + @property + def dbname(self) -> str: + """The database name of the connection. See :pq:`PQdb()`.""" + return self._get_pgconn_attr("db") + + @property + def user(self) -> str: + """The user name of the connection. See :pq:`PQuser()`.""" + return self._get_pgconn_attr("user") + + @property + def password(self) -> str: + """The password of the connection. See :pq:`PQpass()`.""" + return self._get_pgconn_attr("password") + + @property + def options(self) -> str: + """ + The command-line options passed in the connection request. + See :pq:`PQoptions`. + """ + return self._get_pgconn_attr("options") + + def get_parameters(self) -> Dict[str, str]: + """Return the connection parameters values. + + Return all the parameters set to a non-default value, which might come + either from the connection string and parameters passed to + `~Connection.connect()` or from environment variables. The password + is never returned (you can read it using the `password` attribute). + """ + pyenc = self.encoding + + # Get the known defaults to avoid reporting them + defaults = { + i.keyword: i.compiled + for i in pq.Conninfo.get_defaults() + if i.compiled is not None + } + # Not returned by the libq. Bug? Bet we're using SSH. + defaults.setdefault(b"channel_binding", b"prefer") + defaults[b"passfile"] = str(Path.home() / ".pgpass").encode() + + return { + i.keyword.decode(pyenc): i.val.decode(pyenc) + for i in self.pgconn.info + if i.val is not None + and i.keyword != b"password" + and i.val != defaults.get(i.keyword) + } + + @property + def dsn(self) -> str: + """Return the connection string to connect to the database. + + The string contains all the parameters set to a non-default value, + which might come either from the connection string and parameters + passed to `~Connection.connect()` or from environment variables. The + password is never returned (you can read it using the `password` + attribute). + """ + return make_conninfo(**self.get_parameters()) + + @property + def status(self) -> pq.ConnStatus: + """The status of the connection. See :pq:`PQstatus()`.""" + return pq.ConnStatus(self.pgconn.status) + + @property + def transaction_status(self) -> pq.TransactionStatus: + """ + The current in-transaction status of the session. + See :pq:`PQtransactionStatus()`. + """ + return pq.TransactionStatus(self.pgconn.transaction_status) + + @property + def pipeline_status(self) -> pq.PipelineStatus: + """ + The current pipeline status of the client. + See :pq:`PQpipelineStatus()`. + """ + return pq.PipelineStatus(self.pgconn.pipeline_status) + + def parameter_status(self, param_name: str) -> Optional[str]: + """ + Return a parameter setting of the connection. + + Return `None` is the parameter is unknown. + """ + res = self.pgconn.parameter_status(param_name.encode(self.encoding)) + return res.decode(self.encoding) if res is not None else None + + @property + def server_version(self) -> int: + """ + An integer representing the server version. See :pq:`PQserverVersion()`. + """ + return self.pgconn.server_version + + @property + def backend_pid(self) -> int: + """ + The process ID (PID) of the backend process handling this connection. + See :pq:`PQbackendPID()`. + """ + return self.pgconn.backend_pid + + @property + def error_message(self) -> str: + """ + The error message most recently generated by an operation on the connection. + See :pq:`PQerrorMessage()`. + """ + return self._get_pgconn_attr("error_message") + + @property + def timezone(self) -> tzinfo: + """The Python timezone info of the connection's timezone.""" + return get_tzinfo(self.pgconn) + + @property + def encoding(self) -> str: + """The Python codec name of the connection's client encoding.""" + return pgconn_encoding(self.pgconn) + + def _get_pgconn_attr(self, name: str) -> str: + value: bytes = getattr(self.pgconn, name) + return value.decode(self.encoding) + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostname``, ``port`` to allow + connecting without further DNS lookups, eventually removing hosts that are + not resolved, keeping the lists of hosts and ports consistent. + + Raise `~psycopg.OperationalError` if connection is not possible (e.g. no + host resolve, inconsistent lists length). + """ + hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) + if hostaddr_arg: + # Already resolved + return params + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + if not host_arg: + # Nothing to resolve + return params + + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") if port_arg else [] + default_port = "5432" + + if len(ports_in) == 1: + # If only one port is specified, the libpq will apply it to all + # the hosts, so don't mangle it. + default_port = ports_in.pop() + + elif len(ports_in) > 1: + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. + raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + ports_out = [] + + hosts_out = [] + hostaddr_out = [] + loop = asyncio.get_running_loop() + for i, host in enumerate(hosts_in): + if not host or host.startswith("/") or host[1:2] == ":": + # Local path + hosts_out.append(host) + hostaddr_out.append("") + if ports_in: + ports_out.append(ports_in[i]) + continue + + # If the host is already an ip address don't try to resolve it + if is_ip_address(host): + hosts_out.append(host) + hostaddr_out.append(host) + if ports_in: + ports_out.append(ports_in[i]) + continue + + try: + port = ports_in[i] if ports_in else default_port + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + except OSError as ex: + last_exc = ex + else: + for item in ans: + hosts_out.append(host) + hostaddr_out.append(item[4][0]) + if ports_in: + ports_out.append(ports_in[i]) + + # Throw an exception if no host could be resolved + if not hosts_out: + raise e.OperationalError(str(last_exc)) + + out = params.copy() + out["host"] = ",".join(hosts_out) + out["hostaddr"] = ",".join(hostaddr_out) + if ports_in: + out["port"] = ",".join(ports_out) + + return out + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py new file mode 100644 index 0000000..7514306 --- /dev/null +++ b/psycopg/psycopg/copy.py @@ -0,0 +1,904 @@ +""" +psycopg copy support +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import queue +import struct +import asyncio +import threading +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO +from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING + +from . import pq +from . import adapt +from . import errors as e +from .abc import Buffer, ConnectionType, PQGen, Transformer +from ._compat import create_task +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding +from .generators import copy_from, copy_to, copy_end + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +PY_TEXT = adapt.PyFormat.TEXT +PY_BINARY = adapt.PyFormat.BINARY + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COPY_IN = pq.ExecStatus.COPY_IN +COPY_OUT = pq.ExecStatus.COPY_OUT + +ACTIVE = pq.TransactionStatus.ACTIVE + +# Size of data to accumulate before sending it down the network. We fill a +# buffer this size field by field, and when it passes the threshold size +# we ship it, so it may end up being bigger than this. +BUFFER_SIZE = 32 * 1024 + +# Maximum data size we want to queue to send to the libpq copy. Sending a +# buffer too big to be handled can cause an infinite loop in the libpq +# (#255) so we want to split it in more digestable chunks. +MAX_BUFFER_SIZE = 4 * BUFFER_SIZE +# Note: making this buffer too large, e.g. +# MAX_BUFFER_SIZE = 1024 * 1024 +# makes operations *way* slower! Probably triggering some quadraticity +# in the libpq memory management and data sending. + +# Max size of the write queue of buffers. More than that copy will block +# Each buffer should be around BUFFER_SIZE size. +QUEUE_SIZE = 1024 + + +class BaseCopy(Generic[ConnectionType]): + """ + Base implementation for the copy user interface. + + Two subclasses expose real methods with the sync/async differences. + + The difference between the text and binary format is managed by two + different `Formatter` subclasses. + + Writing (the I/O part) is implemented in the subclasses by a `Writer` or + `AsyncWriter` instance. Normally writing implies sending copy data to a + database, but a different writer might be chosen, e.g. to stream data into + a file for later use. + """ + + _Self = TypeVar("_Self", bound="BaseCopy[Any]") + + formatter: "Formatter" + + def __init__( + self, + cursor: "BaseCursor[ConnectionType, Any]", + *, + binary: Optional[bool] = None, + ): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + result = cursor.pgresult + if result: + self._direction = result.status + if self._direction != COPY_IN and self._direction != COPY_OUT: + raise e.ProgrammingError( + "the cursor should have performed a COPY operation;" + f" its status is {pq.ExecStatus(self._direction).name} instead" + ) + else: + self._direction = COPY_IN + + if binary is None: + binary = bool(result and result.binary_tuples) + + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) + if binary: + self.formatter = BinaryFormatter(tx) + else: + self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) + + self._finished = False + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def _enter(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + + def set_types(self, types: Sequence[Union[int, str]]) -> None: + """ + Set the types expected in a COPY operation. + + The types must be specified as a sequence of oid or PostgreSQL type + names (e.g. ``int4``, ``timestamptz[]``). + + This operation overcomes the lack of metadata returned by PostgreSQL + when a COPY operation begins: + + - On :sql:`COPY TO`, `!set_types()` allows to specify what types the + operation returns. If `!set_types()` is not used, the data will be + returned as unparsed strings or bytes instead of Python objects. + + - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the + database expects. This is especially useful in binary copy, because + PostgreSQL will apply no cast rule. + + """ + registry = self.cursor.adapters.types + oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] + + if self._direction == COPY_IN: + self.formatter.transformer.set_dumper_types(oids, self.formatter.format) + else: + self.formatter.transformer.set_loader_types(oids, self.formatter.format) + + # High level copy protocol generators (state change of the Copy object) + + def _read_gen(self) -> PQGen[Buffer]: + if self._finished: + return memoryview(b"") + + res = yield from copy_from(self._pgconn) + if isinstance(res, memoryview): + return res + + # res is the final PGresult + self._finished = True + + # This result is a COMMAND_OK which has info about the number of rows + # returned, but not about the columns, which is instead an information + # that was received on the COPY_OUT result at the beginning of COPY. + # So, don't replace the results in the cursor, just update the rowcount. + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return memoryview(b"") + + def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: + data = yield from self._read_gen() + if not data: + return None + + row = self.formatter.parse_row(data) + if row is None: + # Get the final result to finish the copy operation + yield from self._read_gen() + self._finished = True + return None + + return row + + def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + if not exc: + return + + if self._pgconn.transaction_status != ACTIVE: + # The server has already finished to send copy data. The connection + # is already in a good state. + return + + # Throw a cancel to the server, then consume the rest of the copy data + # (which might or might not have been already transferred entirely to + # the client, so we won't necessary see the exception associated with + # canceling). + self.connection.cancel() + try: + while (yield from self._read_gen()): + pass + except e.QueryCanceled: + pass + + +class Copy(BaseCopy["Connection[Any]"]): + """Manage a :sql:`COPY` operation. + + :param cursor: the cursor where the operation is performed. + :param binary: if `!True`, write binary format. + :param writer: the object to write to destination. If not specified, write + to the `!cursor` connection. + + Choosing `!binary` is not necessary if the cursor has executed a + :sql:`COPY` operation, because the operation result describes the format + too. The parameter is useful when a `!Copy` object is created manually and + no operation is performed on the cursor, such as when using ``writer=``\\ + `~psycopg.copy.FileWriter`. + + """ + + __module__ = "psycopg" + + writer: "Writer" + + def __init__( + self, + cursor: "Cursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["Writer"] = None, + ): + super().__init__(cursor, binary=binary) + if not writer: + writer = LibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.finish(exc_val) + + # End user sync interface + + def __iter__(self) -> Iterator[Buffer]: + """Implement block-by-block iteration on :sql:`COPY TO`.""" + while True: + data = self.read() + if not data: + break + yield data + + def read(self) -> Buffer: + """ + Read an unparsed row after a :sql:`COPY TO` operation. + + Return an empty string when the data is finished. + """ + return self.connection.wait(self._read_gen()) + + def rows(self) -> Iterator[Tuple[Any, ...]]: + """ + Iterate on the result of a :sql:`COPY TO` operation record by record. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + while True: + record = self.read_row() + if record is None: + break + yield record + + def read_row(self) -> Optional[Tuple[Any, ...]]: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return self.connection.wait(self._read_row_gen()) + + def write(self, buffer: Union[Buffer, str]) -> None: + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. + + If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. + """ + data = self.formatter.write(buffer) + if data: + self._write(data) + + def write_row(self, row: Sequence[Any]) -> None: + """Write a record to a table after a :sql:`COPY FROM` operation.""" + data = self.formatter.write_row(row) + if data: + self._write(data) + + def finish(self, exc: Optional[BaseException]) -> None: + """Terminate the copy operation and free the resources allocated. + + You shouldn't need to call this function yourself: it is usually called + by exit. It is available if, despite what is documented, you end up + using the `Copy` object outside a block. + """ + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + self._write(data) + self.writer.finish(exc) + self._finished = True + else: + self.connection.wait(self._end_copy_out_gen(exc)) + + +class Writer(ABC): + """ + A class to write copy data somewhere. + """ + + @abstractmethod + def write(self, data: Buffer) -> None: + """ + Write some data to destination. + """ + ... + + def finish(self, exc: Optional[BaseException] = None) -> None: + """ + Called when write operations are finished. + + If operations finished with an error, it will be passed to ``exc``. + """ + pass + + +class LibpqWriter(Writer): + """ + A `Writer` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "Cursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class QueuedLibpqDriver(LibpqWriter): + """ + A writer using a buffer to queue data to write to a Postgres database. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "Cursor[Any]"): + super().__init__(cursor) + + self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[threading.Thread] = None + self._worker_error: Optional[BaseException] = None + + def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value, or in case + of error. + + The function is designed to be run in a separate thread. + """ + try: + while True: + data = self._queue.get(block=True, timeout=24 * 60 * 60) + if not data: + break + self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. + self._worker_error = ex + + def write(self, data: Buffer) -> None: + if not self._worker: + # warning: reference loop, broken by _write_end + self._worker = threading.Thread(target=self.worker) + self._worker.daemon = True + self._worker.start() + + # If the worker thread raies an exception, re-raise it to the caller. + if self._worker_error: + raise self._worker_error + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + def finish(self, exc: Optional[BaseException] = None) -> None: + self._queue.put(b"") + + if self._worker: + self._worker.join() + self._worker = None # break the loop + + # Check if the worker thread raised any exception before terminating. + if self._worker_error: + raise self._worker_error + + super().finish(exc) + + +class FileWriter(Writer): + """ + A `Writer` to write copy data to a file-like object. + + :param file: the file where to write copy data. It must be open for writing + in binary mode. + """ + + def __init__(self, file: IO[bytes]): + self.file = file + + def write(self, data: Buffer) -> None: + self.file.write(data) # type: ignore[arg-type] + + +class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): + """Manage an asynchronous :sql:`COPY` operation.""" + + __module__ = "psycopg" + + writer: "AsyncWriter" + + def __init__( + self, + cursor: "AsyncCursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["AsyncWriter"] = None, + ): + super().__init__(cursor, binary=binary) + + if not writer: + writer = AsyncLibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.finish(exc_val) + + async def __aiter__(self) -> AsyncIterator[Buffer]: + while True: + data = await self.read() + if not data: + break + yield data + + async def read(self) -> Buffer: + return await self.connection.wait(self._read_gen()) + + async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: + while True: + record = await self.read_row() + if record is None: + break + yield record + + async def read_row(self) -> Optional[Tuple[Any, ...]]: + return await self.connection.wait(self._read_row_gen()) + + async def write(self, buffer: Union[Buffer, str]) -> None: + data = self.formatter.write(buffer) + if data: + await self._write(data) + + async def write_row(self, row: Sequence[Any]) -> None: + data = self.formatter.write_row(row) + if data: + await self._write(data) + + async def finish(self, exc: Optional[BaseException]) -> None: + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + await self._write(data) + await self.writer.finish(exc) + self._finished = True + else: + await self.connection.wait(self._end_copy_out_gen(exc)) + + +class AsyncWriter(ABC): + """ + A class to write copy data somewhere (for async connections). + """ + + @abstractmethod + async def write(self, data: Buffer) -> None: + ... + + async def finish(self, exc: Optional[BaseException] = None) -> None: + pass + + +class AsyncLibpqWriter(AsyncWriter): + """ + An `AsyncWriter` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + async def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class AsyncQueuedLibpqWriter(AsyncLibpqWriter): + """ + An `AsyncWriter` using a buffer to queue data to write. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + super().__init__(cursor) + + self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[asyncio.Future[None]] = None + + async def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value. + + The function is designed to be run in a separate task. + """ + while True: + data = await self._queue.get() + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) + + async def write(self, data: Buffer) -> None: + if not self._worker: + self._worker = create_task(self.worker()) + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + await self._queue.put(b"") + + if self._worker: + await asyncio.gather(self._worker) + self._worker = None # break reference loops if any + + await super().finish(exc) + + +class Formatter(ABC): + """ + A class which understand a copy format (text, binary). + """ + + format: pq.Format + + def __init__(self, transformer: Transformer): + self.transformer = transformer + self._write_buffer = bytearray() + self._row_mode = False # true if the user is using write_row() + + @abstractmethod + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + ... + + @abstractmethod + def write(self, buffer: Union[Buffer, str]) -> Buffer: + ... + + @abstractmethod + def write_row(self, row: Sequence[Any]) -> Buffer: + ... + + @abstractmethod + def end(self) -> Buffer: + ... + + +class TextFormatter(Formatter): + + format = TEXT + + def __init__(self, transformer: Transformer, encoding: str = "utf-8"): + super().__init__(transformer) + self._encoding = encoding + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if data: + return parse_row_text(data, self.transformer) + else: + return None + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + format_row_text(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + return data.encode(self._encoding) + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. + return data + + +class BinaryFormatter(Formatter): + + format = BINARY + + def __init__(self, transformer: Transformer): + super().__init__(transformer) + self._signature_sent = False + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if not self._signature_sent: + if data[: len(_binary_signature)] != _binary_signature: + raise e.DataError( + "binary copy doesn't start with the expected signature" + ) + self._signature_sent = True + data = data[len(_binary_signature) :] + + elif data == _binary_trailer: + return None + + return parse_row_binary(data, self.transformer) + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + if not self._signature_sent: + self._write_buffer += _binary_signature + self._signature_sent = True + + format_row_binary(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + # If we have sent no data we need to send the signature + # and the trailer + if not self._signature_sent: + self._write_buffer += _binary_signature + self._write_buffer += _binary_trailer + + elif self._row_mode: + # if we have sent data already, we have sent the signature + # too (either with the first row, or we assume that in + # block mode the signature is included). + # Write the trailer only if we are sending rows (with the + # assumption that who is copying binary data is sending the + # whole format). + self._write_buffer += _binary_trailer + + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + raise TypeError("cannot copy str data in binary mode: use bytes instead") + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. + return data + + +def _format_row_text( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for copy.""" + if out is None: + out = bytearray() + + if not row: + out += b"\n" + return out + + for item in row: + if item is not None: + dumper = tx.get_dumper(item, PY_TEXT) + b = dumper.dump(item) + out += _dump_re.sub(_dump_sub, b) + else: + out += rb"\N" + out += b"\t" + + out[-1:] = b"\n" + return out + + +def _format_row_binary( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for binary copy.""" + if out is None: + out = bytearray() + + out += _pack_int2(len(row)) + adapted = tx.dump_sequence(row, [PY_BINARY] * len(row)) + for b in adapted: + if b is not None: + out += _pack_int4(len(b)) + out += b + else: + out += _binary_null + + return out + + +def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + if not isinstance(data, bytes): + data = bytes(data) + fields = data.split(b"\t") + fields[-1] = fields[-1][:-1] # drop \n + row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] + return tx.load_sequence(row) + + +def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + row: List[Optional[Buffer]] = [] + nfields = _unpack_int2(data, 0)[0] + pos = 2 + for i in range(nfields): + length = _unpack_int4(data, pos)[0] + pos += 4 + if length >= 0: + row.append(data[pos : pos + length]) + pos += length + else: + row.append(None) + + return tx.load_sequence(row) + + +_pack_int2 = struct.Struct("!h").pack +_pack_int4 = struct.Struct("!i").pack +_unpack_int2 = struct.Struct("!h").unpack_from +_unpack_int4 = struct.Struct("!i").unpack_from + +_binary_signature = ( + b"PGCOPY\n\xff\r\n\0" # Signature + b"\x00\x00\x00\x00" # flags + b"\x00\x00\x00\x00" # extra length +) +_binary_trailer = b"\xff\xff" +_binary_null = b"\xff\xff\xff\xff" + +_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") +_dump_repl = { + b"\b": b"\\b", + b"\t": b"\\t", + b"\n": b"\\n", + b"\v": b"\\v", + b"\f": b"\\f", + b"\r": b"\\r", + b"\\": b"\\\\", +} + + +def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes: + return __map[m.group(0)] + + +_load_re = re.compile(b"\\\\[btnvfr\\\\]") +_load_repl = {v: k for k, v in _dump_repl.items()} + + +def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes: + return __map[m.group(0)] + + +# Override functions with fast versions if available +if _psycopg: + format_row_text = _psycopg.format_row_text + format_row_binary = _psycopg.format_row_binary + parse_row_text = _psycopg.parse_row_text + parse_row_binary = _psycopg.parse_row_binary + +else: + format_row_text = _format_row_text + format_row_binary = _format_row_binary + parse_row_text = _parse_row_text + parse_row_binary = _parse_row_binary diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py new file mode 100644 index 0000000..323903a --- /dev/null +++ b/psycopg/psycopg/crdb/__init__.py @@ -0,0 +1,19 @@ +""" +CockroachDB support package. +""" + +# Copyright (C) 2022 The Psycopg Team + +from . import _types +from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo + +adapters = _types.adapters # exposed by the package +connect = CrdbConnection.connect + +_types.register_crdb_adapters(adapters) + +__all__ = [ + "AsyncCrdbConnection", + "CrdbConnection", + "CrdbConnectionInfo", +] diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py new file mode 100644 index 0000000..5311e05 --- /dev/null +++ b/psycopg/psycopg/crdb/_types.py @@ -0,0 +1,163 @@ +""" +Types configuration specific for CockroachDB. +""" + +# Copyright (C) 2022 The Psycopg Team + +from enum import Enum +from .._typeinfo import TypeInfo, TypesRegistry + +from ..abc import AdaptContext, NoneType +from ..postgres import TEXT_OID +from .._adapters_map import AdaptersMap +from ..types.enum import EnumDumper, EnumBinaryDumper +from ..types.none import NoneDumper + +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + + +class CrdbEnumDumper(EnumDumper): + oid = TEXT_OID + + +class CrdbEnumBinaryDumper(EnumBinaryDumper): + oid = TEXT_OID + + +class CrdbNoneDumper(NoneDumper): + oid = TEXT_OID + + +def register_postgres_adapters(context: AdaptContext) -> None: + # Same adapters used by PostgreSQL, or a good starting point for customization + + from ..types import array, bool, composite, datetime + from ..types import numeric, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + numeric.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) + + +def register_crdb_adapters(context: AdaptContext) -> None: + from .. import dbapi20 + from ..types import array + + register_postgres_adapters(context) + + # String must come after enum to map text oid -> string dumper + register_crdb_enum_adapters(context) + register_crdb_string_adapters(context) + register_crdb_json_adapters(context) + register_crdb_net_adapters(context) + register_crdb_none_adapters(context) + + dbapi20.register_dbapi20_adapters(adapters) + + array.register_all_arrays(adapters) + + +def register_crdb_string_adapters(context: AdaptContext) -> None: + from ..types import string + + # Dump strings with text oid instead of unknown. + # Unlike PostgreSQL, CRDB seems able to cast text to most types. + context.adapters.register_dumper(str, string.StrDumper) + context.adapters.register_dumper(str, string.StrBinaryDumper) + + +def register_crdb_enum_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper) + context.adapters.register_dumper(Enum, CrdbEnumDumper) + + +def register_crdb_json_adapters(context: AdaptContext) -> None: + from ..types import json + + adapters = context.adapters + + # CRDB doesn't have json/jsonb: both names map to the jsonb oid + adapters.register_dumper(json.Json, json.JsonbBinaryDumper) + adapters.register_dumper(json.Json, json.JsonbDumper) + + adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper) + adapters.register_dumper(json.Jsonb, json.JsonbDumper) + + adapters.register_loader("json", json.JsonLoader) + adapters.register_loader("jsonb", json.JsonbLoader) + adapters.register_loader("json", json.JsonBinaryLoader) + adapters.register_loader("jsonb", json.JsonbBinaryLoader) + + +def register_crdb_net_adapters(context: AdaptContext) -> None: + from ..types import net + + adapters = context.adapters + + adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper) + adapters.register_dumper(None, net.InetBinaryDumper) + adapters.register_loader("inet", net.InetLoader) + adapters.register_loader("inet", net.InetBinaryLoader) + + +def register_crdb_none_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, CrdbNoneDumper) + + +for t in [ + TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb. + TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8 + TypeInfo('"char"', 18, 1002), # special case, not generated + # autogenerated: start + # Generated from CockroachDB 22.1.0 + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("record", 2249, 2287), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("unknown", 705, 0), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + # autogenerated: end +]: + types.add(t) diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py new file mode 100644 index 0000000..6e79ed1 --- /dev/null +++ b/psycopg/psycopg/crdb/connection.py @@ -0,0 +1,186 @@ +""" +CockroachDB-specific connections. +""" + +# Copyright (C) 2022 The Psycopg Team + +import re +from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING + +from .. import errors as e +from ..abc import AdaptContext +from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow +from ..conninfo import ConnectionInfo +from ..connection import Connection +from .._adapters_map import AdaptersMap +from ..connection_async import AsyncConnection +from ._types import adapters + +if TYPE_CHECKING: + from ..pq.abc import PGconn + from ..cursor import Cursor + from ..cursor_async import AsyncCursor + + +class _CrdbConnectionMixin: + + _adapters: Optional[AdaptersMap] + pgconn: "PGconn" + + @classmethod + def is_crdb( + cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"] + ) -> bool: + """ + Return `!True` if the server connected to `!conn` is CockroachDB. + """ + if isinstance(conn, (Connection, AsyncConnection)): + conn = conn.pgconn + + return bool(conn.parameter_status(b"crdb_version")) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + # By default, use CockroachDB adapters map + self._adapters = AdaptersMap(adapters) + + return self._adapters + + @property + def info(self) -> "CrdbConnectionInfo": + return CrdbConnectionInfo(self.pgconn) + + def _check_tpc(self) -> None: + if self.is_crdb(self.pgconn): + raise e.NotSupportedError("CockroachDB doesn't support prepared statements") + + +class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): + """ + Wrapper for a connection to a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[Row]": + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[TupleRow]": + ... + + @classmethod + def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]": + """ + Connect to a database server and return a new `CrdbConnection` instance. + """ + return super().connect(conninfo, **kwargs) # type: ignore[return-value] + + +class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]): + """ + Wrapper for an async connection to a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[Row]": + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[TupleRow]": + ... + + @classmethod + async def connect( + cls, conninfo: str = "", **kwargs: Any + ) -> "AsyncCrdbConnection[Any]": + return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return] + + +class CrdbConnectionInfo(ConnectionInfo): + """ + `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + @property + def vendor(self) -> str: + return "CockroachDB" + + @property + def server_version(self) -> int: + """ + Return the CockroachDB server version connected. + + Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210). + """ + sver = self.parameter_status("crdb_version") + if not sver: + raise e.InternalError("'crdb_version' parameter status not set") + + ver = self.parse_crdb_version(sver) + if ver is None: + raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}") + + return ver + + @classmethod + def parse_crdb_version(self, sver: str) -> Optional[int]: + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + return None + + return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py new file mode 100644 index 0000000..42c3804 --- /dev/null +++ b/psycopg/psycopg/cursor.py @@ -0,0 +1,921 @@ +""" +psycopg cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from functools import partial +from types import TracebackType +from typing import Any, Generic, Iterable, Iterator, List +from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar +from typing import overload, TYPE_CHECKING +from contextlib import contextmanager + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .copy import Copy, Writer as CopyWriter +from .rows import Row, RowMaker, RowFactory +from ._column import Column +from ._queries import PostgresQuery, PostgresClientQuery +from ._pipeline import Pipeline +from ._encodings import pgconn_encoding +from ._preparing import Prepare +from .generators import execute, fetch, send + +if TYPE_CHECKING: + from .abc import Transformer + from .pq.abc import PGconn, PGresult + from .connection import Connection + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class BaseCursor(Generic[ConnectionType, Row]): + __slots__ = """ + _conn format _adapters arraysize _closed _results pgresult _pos + _iresult _rowcount _query _tx _last_query _row_factory _make_row + _pgconn _execmany_returning + __weakref__ + """.split() + + ExecStatus = pq.ExecStatus + + _tx: "Transformer" + _make_row: RowMaker[Row] + _pgconn: "PGconn" + + def __init__(self, connection: ConnectionType): + self._conn = connection + self.format = TEXT + self._pgconn = connection.pgconn + self._adapters = adapt.AdaptersMap(connection.adapters) + self.arraysize = 1 + self._closed = False + self._last_query: Optional[Query] = None + self._reset() + + def _reset(self, reset_query: bool = True) -> None: + self._results: List["PGresult"] = [] + self.pgresult: Optional["PGresult"] = None + self._pos = 0 + self._iresult = 0 + self._rowcount = -1 + self._query: Optional[PostgresQuery] + # None if executemany() not executing, True/False according to returning state + self._execmany_returning: Optional[bool] = None + if reset_query: + self._query = None + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + if self._closed: + status = "closed" + elif self.pgresult: + status = pq.ExecStatus(self.pgresult.status).name + else: + status = "no result" + return f"<{cls} [{status}] {info} at 0x{id(self):x}>" + + @property + def connection(self) -> ConnectionType: + """The connection this cursor is using.""" + return self._conn + + @property + def adapters(self) -> adapt.AdaptersMap: + return self._adapters + + @property + def closed(self) -> bool: + """`True` if the cursor is closed.""" + return self._closed + + @property + def description(self) -> Optional[List[Column]]: + """ + A list of `Column` objects describing the current resultset. + + `!None` if the current resultset didn't return tuples. + """ + res = self.pgresult + + # We return columns if we have nfields, but also if we don't but + # the query said we got tuples (mostly to handle the super useful + # query "SELECT ;" + if res and ( + res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE + ): + return [Column(self, i) for i in range(res.nfields)] + else: + return None + + @property + def rowcount(self) -> int: + """Number of records affected by the precedent operation.""" + return self._rowcount + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + tuples = self.pgresult and self.pgresult.status == TUPLES_OK + return self._pos if tuples else None + + def setinputsizes(self, sizes: Sequence[Any]) -> None: + # no-op + pass + + def setoutputsize(self, size: Any, column: Optional[int] = None) -> None: + # no-op + pass + + def nextset(self) -> Optional[bool]: + """ + Move to the result set of the next query executed through `executemany()` + or to the next result set if `execute()` returned more than one. + + Return `!True` if a new result is available, which will be the one + methods `!fetch*()` will operate on. + """ + if self._iresult < len(self._results) - 1: + self._select_current_result(self._iresult + 1) + return True + else: + return None + + @property + def statusmessage(self) -> Optional[str]: + """ + The command status tag from the last SQL command executed. + + `!None` if the cursor doesn't have a result available. + """ + msg = self.pgresult.command_status if self.pgresult else None + return msg.decode() if msg else None + + def _make_row_maker(self) -> RowMaker[Row]: + raise NotImplementedError + + # + # Generators for the high level operations on the cursor + # + # Like for sync/async connections, these are implemented as generators + # so that different concurrency strategies (threads,asyncio) can use their + # own way of waiting (or better, `connection.wait()`). + # + + def _execute_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `Cursor.execute()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + results = yield from self._maybe_prepare_gen( + pgq, prepare=prepare, binary=binary + ) + if self._conn._pipeline: + yield from self._conn._pipeline._communicate_gen() + else: + assert results is not None + self._check_results(results) + self._results = results + self._select_current_result(0) + + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines available. + """ + pipeline = self._conn._pipeline + assert pipeline + + yield from self._start_query(query) + self._rowcount = 0 + + assert self._execmany_returning is None + self._execmany_returning = returning + + first = True + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + yield from self._maybe_prepare_gen(pgq, prepare=True) + yield from pipeline._communicate_gen() + + self._last_query = query + + if returning: + yield from pipeline._fetch_gen(flush=True) + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_no_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines not available. + """ + yield from self._start_query(query) + first = True + nrows = 0 + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + results = yield from self._maybe_prepare_gen(pgq, prepare=True) + assert results is not None + self._check_results(results) + if returning: + self._results.extend(results) + + for res in results: + nrows += res.command_tuples or 0 + + if self._results: + self._select_current_result(0) + + # Override rowcount for the first result. Calls to nextset() will change + # it to the value of that result only, but we hope nobody will notice. + # You haven't read this comment. + self._rowcount = nrows + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _maybe_prepare_gen( + self, + pgq: PostgresQuery, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[Optional[List["PGresult"]]]: + # Check if the query is prepared or needs preparing + prep, name = self._get_prepared(pgq, prepare) + if prep is Prepare.NO: + # The query must be executed without preparing + self._execute_send(pgq, binary=binary) + else: + # If the query is not already prepared, prepare it. + if prep is Prepare.SHOULD: + self._send_prepare(name, pgq) + if not self._conn._pipeline: + (result,) = yield from execute(self._pgconn) + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + # Then execute it. + self._send_query_prepared(name, pgq, binary=binary) + + # Update the prepare state of the query. + # If an operation requires to flush our prepared statements cache, + # it will be added to the maintenance commands to execute later. + key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name) + + if self._conn._pipeline: + queued = None + if key is not None: + queued = (key, prep, name) + self._conn._pipeline.result_queue.append((self, queued)) + return None + + # run the query + results = yield from execute(self._pgconn) + + if key is not None: + self._conn._prepared.validate(key, prep, name, results) + + return results + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return self._conn._prepared.get(pgq, prepare) + + def _stream_send_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator to send the query for `Cursor.stream()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, binary=binary, force_extended=True) + self._pgconn.set_single_row_mode() + self._last_query = query + yield from send(self._pgconn) + + def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]: + res = yield from fetch(self._pgconn) + if res is None: + return None + + status = res.status + if status == SINGLE_TUPLE: + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=first) + if first: + self._make_row = self._make_row_maker() + return res + + elif status == TUPLES_OK or status == COMMAND_OK: + # End of single row results + while res: + res = yield from fetch(self._pgconn) + if status != TUPLES_OK: + raise e.ProgrammingError( + "the operation in stream() didn't produce a result" + ) + return None + + else: + # Errors, unexpected values + return self._raise_for_result(res) + + def _start_query(self, query: Optional[Query] = None) -> PQGen[None]: + """Generator to start the processing of a query. + + It is implemented as generator because it may send additional queries, + such as `begin`. + """ + if self.closed: + raise e.InterfaceError("the cursor is closed") + + self._reset() + if not self._last_query or (self._last_query is not query): + self._last_query = None + self._tx = adapt.Transformer(self) + yield from self._conn._start_query() + + def _start_copy_gen( + self, statement: Query, params: Optional[Params] = None + ) -> PQGen[None]: + """Generator implementing sending a command for `Cursor.copy().""" + + # The connection gets in an unrecoverable state if we attempt COPY in + # pipeline mode. Forbid it explicitly. + if self._conn._pipeline: + raise e.NotSupportedError("COPY cannot be used in pipeline mode") + + yield from self._start_query() + + # Merge the params client-side + if params: + pgq = PostgresClientQuery(self._tx) + pgq.convert(statement, params) + statement = pgq.query + + query = self._convert_query(statement) + + self._execute_send(query, binary=False) + results = yield from execute(self._pgconn) + if len(results) != 1: + raise e.ProgrammingError("COPY cannot be mixed with other operations") + + self._check_copy_result(results[0]) + self._results = results + self._select_current_result(0) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + """ + Implement part of execute() before waiting common to sync and async. + + This is not a generator, but a normal non-blocking function. + """ + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_params, + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + ) + elif force_extended or query.params or fmt == BINARY: + self._pgconn.send_query_params( + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. + self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _check_results(self, results: List["PGresult"]) -> None: + """ + Verify that the results of a query are valid. + + Verify that the query returned at least one result and that they all + represent a valid result from the database. + """ + if not results: + raise e.InternalError("got no result from the query") + + for res in results: + status = res.status + if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY: + self._raise_for_result(res) + + def _raise_for_result(self, result: "PGresult") -> NoReturn: + """ + Raise an appropriate error message for an unexpected database result + """ + status = result.status + if status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + raise e.ProgrammingError( + "COPY cannot be used with this method; use copy() instead" + ) + else: + raise e.InternalError( + "unexpected result status from query:" f" {pq.ExecStatus(status).name}" + ) + + def _select_current_result( + self, i: int, format: Optional[pq.Format] = None + ) -> None: + """ + Select one of the results in the cursor as the active one. + """ + self._iresult = i + res = self.pgresult = self._results[i] + + # Note: the only reason to override format is to correctly set + # binary loaders on server-side cursors, because send_describe_portal + # only returns a text result. + self._tx.set_pgresult(res, format=format) + + self._pos = 0 + + if res.status == TUPLES_OK: + self._rowcount = self.pgresult.ntuples + + # COPY_OUT has never info about nrows. We need such result for the + # columns in order to return a `description`, but not overwrite the + # cursor rowcount (which was set by the Copy object). + elif res.status != COPY_OUT: + nrows = self.pgresult.command_tuples + self._rowcount = nrows if nrows is not None else -1 + + self._make_row = self._make_row_maker() + + def _set_results_from_pipeline(self, results: List["PGresult"]) -> None: + self._check_results(results) + first_batch = not self._results + + if self._execmany_returning is None: + # Received from execute() + self._results.extend(results) + if first_batch: + self._select_current_result(0) + + else: + # Received from executemany() + if self._execmany_returning: + self._results.extend(results) + if first_batch: + self._select_current_result(0) + self._rowcount = 0 + + # Override rowcount for the first result. Calls to nextset() will + # change it to the value of that result only, but we hope nobody + # will notice. + # You haven't read this comment. + if self._rowcount < 0: + self._rowcount = 0 + for res in results: + self._rowcount += res.command_tuples or 0 + + def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_prepare, + name, + query.query, + param_types=query.types, + ) + ) + self._conn._pipeline.result_queue.append(None) + else: + self._pgconn.send_prepare(name, query.query, param_types=query.types) + + def _send_query_prepared( + self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_prepared, + name, + pgq.params, + param_formats=pgq.formats, + result_format=fmt, + ) + ) + else: + self._pgconn.send_query_prepared( + name, pgq.params, param_formats=pgq.formats, result_format=fmt + ) + + def _check_result_for_fetch(self) -> None: + if self.closed: + raise e.InterfaceError("the cursor is closed") + res = self.pgresult + if not res: + raise e.ProgrammingError("no result available") + + status = res.status + if status == TUPLES_OK: + return + elif status == FATAL_ERROR: + raise e.error_from_result(res, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + raise e.ProgrammingError("the last operation didn't produce a result") + + def _check_copy_result(self, result: "PGresult") -> None: + """ + Check that the value returned in a copy() operation is a legit COPY. + """ + status = result.status + if status == COPY_IN or status == COPY_OUT: + return + elif status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + else: + raise e.ProgrammingError( + "copy() should be used only with COPY ... TO STDOUT or COPY ..." + f" FROM STDIN statements, got {pq.ExecStatus(status).name}" + ) + + def _scroll(self, value: int, mode: str) -> None: + self._check_result_for_fetch() + assert self.pgresult + if mode == "relative": + newpos = self._pos + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + if not 0 <= newpos < self.pgresult.ntuples: + raise IndexError("position out of bound") + self._pos = newpos + + def _close(self) -> None: + """Non-blocking part of closing. Common to sync/async.""" + # Don't reset the query because it may be useful to investigate after + # an error. + self._reset(reset_query=False) + self._closed = True + + @property + def _encoding(self) -> str: + return pgconn_encoding(self._pgconn) + + +class Cursor(BaseCursor["Connection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="Cursor[Any]") + + @overload + def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): + ... + + @overload + def __init__( + self: "Cursor[Row]", + connection: "Connection[Any]", + *, + row_factory: RowFactory[Row], + ): + ... + + def __init__( + self, + connection: "Connection[Any]", + *, + row_factory: Optional[RowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + self._close() + + @property + def row_factory(self) -> RowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: RowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + """ + Execute a query or command to the database. + """ + try: + with self._conn.lock: + self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + """ + Execute the same command with a sequence of input data. + """ + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + with self._conn.lock: + p = self._conn._pipeline + if p: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + with self._conn.pipeline(), self._conn.lock: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + with self._conn.lock: + self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> Iterator[Row]: + """ + Iterate row-by-row on a result from the database. + """ + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + with self._conn.lock: + + try: + self._conn.wait(self._stream_send_gen(query, params, binary=binary)) + first = True + while self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. + self._conn.cancel() + try: + while self._conn.wait(self._stream_fetchone_gen(first=False)): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + def fetchone(self) -> Optional[Row]: + """ + Return the next record from the current recordset. + + Return `!None` the recordset is finished. + + :rtype: Optional[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + record = self._tx.load_row(self._pos, self._make_row) + if record is not None: + self._pos += 1 + return record + + def fetchmany(self, size: int = 0) -> List[Row]: + """ + Return the next `!size` records from the current recordset. + + `!size` default to `!self.arraysize` if not specified. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + def fetchall(self) -> List[Row]: + """ + Return all the remaining records from the current recordset. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + def __iter__(self) -> Iterator[Row]: + self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + def scroll(self, value: int, mode: str = "relative") -> None: + """ + Move the cursor in the result set to a new position according to mode. + + If `!mode` is ``'relative'`` (default), `!value` is taken as offset to + the current position in the result set; if set to ``'absolute'``, + `!value` states an absolute target position. + + Raise `!IndexError` in case a scroll operation would leave the result + set. In this case the position will not change. + """ + self._fetch_pipeline() + self._scroll(value, mode) + + @contextmanager + def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[CopyWriter] = None, + ) -> Iterator[Copy]: + """ + Initiate a :sql:`COPY` operation and return an object to manage it. + + :rtype: Copy + """ + try: + with self._conn.lock: + self._conn.wait(self._start_copy_gen(statement, params)) + + with Copy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + # If a fresher result has been set on the cursor by the Copy object, + # read its properties (especially rowcount). + self._select_current_result(0) + + def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + with self._conn.lock: + self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py new file mode 100644 index 0000000..8971d40 --- /dev/null +++ b/psycopg/psycopg/cursor_async.py @@ -0,0 +1,250 @@ +""" +psycopg async cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from types import TracebackType +from typing import Any, AsyncIterator, Iterable, List +from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload +from contextlib import asynccontextmanager + +from . import pq +from . import errors as e +from .abc import Query, Params +from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter +from .rows import Row, RowMaker, AsyncRowFactory +from .cursor import BaseCursor +from ._pipeline import Pipeline + +if TYPE_CHECKING: + from .connection_async import AsyncConnection + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncCursor[Any]") + + @overload + def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): + ... + + @overload + def __init__( + self: "AsyncCursor[Row]", + connection: "AsyncConnection[Any]", + *, + row_factory: AsyncRowFactory[Row], + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def close(self) -> None: + self._close() + + @property + def row_factory(self) -> AsyncRowFactory[Row]: + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + try: + async with self._conn.lock: + await self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + async with self._conn.lock: + p = self._conn._pipeline + if p: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + async with self._conn.pipeline(), self._conn.lock: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + await self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + async def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> AsyncIterator[Row]: + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + async with self._conn.lock: + + try: + await self._conn.wait( + self._stream_send_gen(query, params, binary=binary) + ) + first = True + while await self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. + self._conn.cancel() + try: + while await self._conn.wait( + self._stream_fetchone_gen(first=False) + ): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + await self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + async def fetchone(self) -> Optional[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + rv = self._tx.load_row(self._pos, self._make_row) + if rv is not None: + self._pos += 1 + return rv + + async def fetchmany(self, size: int = 0) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + async def fetchall(self) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + async def __aiter__(self) -> AsyncIterator[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + async def scroll(self, value: int, mode: str = "relative") -> None: + self._scroll(value, mode) + + @asynccontextmanager + async def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[AsyncCopyWriter] = None, + ) -> AsyncIterator[AsyncCopy]: + """ + :rtype: AsyncCopy + """ + try: + async with self._conn.lock: + await self._conn.wait(self._start_copy_gen(statement, params)) + + async with AsyncCopy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + self._select_current_result(0) + + async def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + async with self._conn.lock: + await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py new file mode 100644 index 0000000..3c3d8b7 --- /dev/null +++ b/psycopg/psycopg/dbapi20.py @@ -0,0 +1,112 @@ +""" +Compatibility objects with DBAPI 2.0 +""" + +# Copyright (C) 2020 The Psycopg Team + +import time +import datetime as dt +from math import floor +from typing import Any, Sequence, Union + +from . import postgres +from .abc import AdaptContext, Buffer +from .types.string import BytesDumper, BytesBinaryDumper + + +class DBAPITypeObject: + def __init__(self, name: str, type_names: Sequence[str]): + self.name = name + self.values = tuple(postgres.types[n].oid for n in type_names) + + def __repr__(self) -> str: + return f"psycopg.{self.name}" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, int): + return other in self.values + else: + return NotImplemented + + def __ne__(self, other: Any) -> bool: + if isinstance(other, int): + return other not in self.values + else: + return NotImplemented + + +BINARY = DBAPITypeObject("BINARY", ("bytea",)) +DATETIME = DBAPITypeObject( + "DATETIME", "timestamp timestamptz date time timetz interval".split() +) +NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split()) +ROWID = DBAPITypeObject("ROWID", ("oid",)) +STRING = DBAPITypeObject("STRING", "text varchar bpchar".split()) + + +class Binary: + def __init__(self, obj: Any): + self.obj = obj + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)" + return f"{self.__class__.__name__}({sobj})" + + +class BinaryBinaryDumper(BytesBinaryDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +class BinaryTextDumper(BytesDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +def Date(year: int, month: int, day: int) -> dt.date: + return dt.date(year, month, day) + + +def DateFromTicks(ticks: float) -> dt.date: + return TimestampFromTicks(ticks).date() + + +def Time(hour: int, minute: int, second: int) -> dt.time: + return dt.time(hour, minute, second) + + +def TimeFromTicks(ticks: float) -> dt.time: + return TimestampFromTicks(ticks).time() + + +def Timestamp( + year: int, month: int, day: int, hour: int, minute: int, second: int +) -> dt.datetime: + return dt.datetime(year, month, day, hour, minute, second) + + +def TimestampFromTicks(ticks: float) -> dt.datetime: + secs = floor(ticks) + frac = ticks - secs + t = time.localtime(ticks) + tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff)) + rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo) + return rv + + +def register_dbapi20_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Binary, BinaryTextDumper) + adapters.register_dumper(Binary, BinaryBinaryDumper) + + # Make them also the default dumpers when dumping by bytea oid + adapters.register_dumper(None, BinaryTextDumper) + adapters.register_dumper(None, BinaryBinaryDumper) diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py new file mode 100644 index 0000000..e176954 --- /dev/null +++ b/psycopg/psycopg/errors.py @@ -0,0 +1,1535 @@ +""" +psycopg exceptions + +DBAPI-defined Exceptions are defined in the following hierarchy:: + + Exceptions + |__Warning + |__Error + |__InterfaceError + |__DatabaseError + |__DataError + |__OperationalError + |__IntegrityError + |__InternalError + |__ProgrammingError + |__NotSupportedError +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing_extensions import TypeAlias + +from .pq.abc import PGconn, PGresult +from .pq._enums import DiagnosticField +from ._compat import TypeGuard + +ErrorInfo: TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]] + +_sqlcodes: Dict[str, "Type[Error]"] = {} + + +class Warning(Exception): + """ + Exception raised for important warnings. + + Defined for DBAPI compatibility, but never raised by ``psycopg``. + """ + + __module__ = "psycopg" + + +class Error(Exception): + """ + Base exception for all the errors psycopg will raise. + + Exception that is the base class of all other error exceptions. You can + use this to catch all errors with one single `!except` statement. + + This exception is guaranteed to be picklable. + """ + + __module__ = "psycopg" + + sqlstate: Optional[str] = None + + def __init__( + self, + *args: Sequence[Any], + info: ErrorInfo = None, + encoding: str = "utf-8", + pgconn: Optional[PGconn] = None + ): + super().__init__(*args) + self._info = info + self._encoding = encoding + self._pgconn = pgconn + + # Handle sqlstate codes for which we don't have a class. + if not self.sqlstate and info: + self.sqlstate = self.diag.sqlstate + + @property + def pgconn(self) -> Optional[PGconn]: + """The connection object, if the error was raised from a connection attempt. + + :rtype: Optional[psycopg.pq.PGconn] + """ + return self._pgconn if self._pgconn else None + + @property + def pgresult(self) -> Optional[PGresult]: + """The result object, if the exception was raised after a failed query. + + :rtype: Optional[psycopg.pq.PGresult] + """ + return self._info if _is_pgresult(self._info) else None + + @property + def diag(self) -> "Diagnostic": + """ + A `Diagnostic` object to inspect details of the errors from the database. + """ + return Diagnostic(self._info, encoding=self._encoding) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + # To make the exception picklable + res[2]["_info"] = _info_to_dict(self._info) + res[2]["_pgconn"] = None + + return res + + +class InterfaceError(Error): + """ + An error related to the database interface rather than the database itself. + """ + + __module__ = "psycopg" + + +class DatabaseError(Error): + """ + Exception raised for errors that are related to the database. + """ + + __module__ = "psycopg" + + def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None): + if code: + _sqlcodes[code] = cls + cls.sqlstate = code + if name: + _sqlcodes[name] = cls + + +class DataError(DatabaseError): + """ + An error caused by problems with the processed data. + + Examples may be division by zero, numeric value out of range, etc. + """ + + __module__ = "psycopg" + + +class OperationalError(DatabaseError): + """ + An error related to the database's operation. + + These errors are not necessarily under the control of the programmer, e.g. + an unexpected disconnect occurs, the data source name is not found, a + transaction could not be processed, a memory allocation error occurred + during processing, etc. + """ + + __module__ = "psycopg" + + +class IntegrityError(DatabaseError): + """ + An error caused when the relational integrity of the database is affected. + + An example may be a foreign key check failed. + """ + + __module__ = "psycopg" + + +class InternalError(DatabaseError): + """ + An error generated when the database encounters an internal error, + + Examples could be the cursor is not valid anymore, the transaction is out + of sync, etc. + """ + + __module__ = "psycopg" + + +class ProgrammingError(DatabaseError): + """ + Exception raised for programming errors + + Examples may be table not found or already exists, syntax error in the SQL + statement, wrong number of parameters specified, etc. + """ + + __module__ = "psycopg" + + +class NotSupportedError(DatabaseError): + """ + A method or database API was used which is not supported by the database. + """ + + __module__ = "psycopg" + + +class ConnectionTimeout(OperationalError): + """ + Exception raised on timeout of the `~psycopg.Connection.connect()` method. + + The error is raised if the ``connect_timeout`` is specified and a + connection is not obtained in useful time. + + Subclass of `~psycopg.OperationalError`. + """ + + +class PipelineAborted(OperationalError): + """ + Raised when a operation fails because the current pipeline is in aborted state. + + Subclass of `~psycopg.OperationalError`. + """ + + +class Diagnostic: + """Details from a database error report.""" + + def __init__(self, info: ErrorInfo, encoding: str = "utf-8"): + self._info = info + self._encoding = encoding + + @property + def severity(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY) + + @property + def severity_nonlocalized(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED) + + @property + def sqlstate(self) -> Optional[str]: + return self._error_message(DiagnosticField.SQLSTATE) + + @property + def message_primary(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_PRIMARY) + + @property + def message_detail(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_DETAIL) + + @property + def message_hint(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_HINT) + + @property + def statement_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.STATEMENT_POSITION) + + @property + def internal_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_POSITION) + + @property + def internal_query(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_QUERY) + + @property + def context(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONTEXT) + + @property + def schema_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.SCHEMA_NAME) + + @property + def table_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.TABLE_NAME) + + @property + def column_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.COLUMN_NAME) + + @property + def datatype_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.DATATYPE_NAME) + + @property + def constraint_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONSTRAINT_NAME) + + @property + def source_file(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FILE) + + @property + def source_line(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_LINE) + + @property + def source_function(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FUNCTION) + + def _error_message(self, field: DiagnosticField) -> Optional[str]: + if self._info: + if isinstance(self._info, dict): + val = self._info.get(field) + else: + val = self._info.error_field(field) + + if val is not None: + return val.decode(self._encoding, "replace") + + return None + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + res[2]["_info"] = _info_to_dict(self._info) + + return res + + +def _info_to_dict(info: ErrorInfo) -> ErrorInfo: + """ + Convert a PGresult to a dictionary to make the info picklable. + """ + # PGresult is a protocol, can't use isinstance + if _is_pgresult(info): + return {v: info.error_field(v) for v in DiagnosticField} + else: + return info + + +def lookup(sqlstate: str) -> Type[Error]: + """Lookup an error code or `constant name`__ and return its exception class. + + Raise `!KeyError` if the code is not found. + + .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html + #ERRCODES-TABLE + """ + return _sqlcodes[sqlstate.upper()] + + +def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error: + from psycopg import pq + + state = result.error_field(DiagnosticField.SQLSTATE) or b"" + cls = _class_for_state(state.decode("ascii")) + return cls( + pq.error_message(result, encoding=encoding), + info=result, + encoding=encoding, + ) + + +def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]: + """Return True if an ErrorInfo is a PGresult instance.""" + # PGresult is a protocol, can't use isinstance + return hasattr(info, "error_field") + + +def _class_for_state(sqlstate: str) -> Type[Error]: + try: + return lookup(sqlstate) + except KeyError: + return get_base_exception(sqlstate) + + +def get_base_exception(sqlstate: str) -> Type[Error]: + return ( + _base_exc_map.get(sqlstate[:2]) + or _base_exc_map.get(sqlstate[:1]) + or DatabaseError + ) + + +_base_exc_map = { + "08": OperationalError, # Connection Exception + "0A": NotSupportedError, # Feature Not Supported + "20": ProgrammingError, # Case Not Foud + "21": ProgrammingError, # Cardinality Violation + "22": DataError, # Data Exception + "23": IntegrityError, # Integrity Constraint Violation + "24": InternalError, # Invalid Cursor State + "25": InternalError, # Invalid Transaction State + "26": ProgrammingError, # Invalid SQL Statement Name * + "27": OperationalError, # Triggered Data Change Violation + "28": OperationalError, # Invalid Authorization Specification + "2B": InternalError, # Dependent Privilege Descriptors Still Exist + "2D": InternalError, # Invalid Transaction Termination + "2F": OperationalError, # SQL Routine Exception * + "34": ProgrammingError, # Invalid Cursor Name * + "38": OperationalError, # External Routine Exception * + "39": OperationalError, # External Routine Invocation Exception * + "3B": OperationalError, # Savepoint Exception * + "3D": ProgrammingError, # Invalid Catalog Name + "3F": ProgrammingError, # Invalid Schema Name + "40": OperationalError, # Transaction Rollback + "42": ProgrammingError, # Syntax Error or Access Rule Violation + "44": ProgrammingError, # WITH CHECK OPTION Violation + "53": OperationalError, # Insufficient Resources + "54": OperationalError, # Program Limit Exceeded + "55": OperationalError, # Object Not In Prerequisite State + "57": OperationalError, # Operator Intervention + "58": OperationalError, # System Error (errors external to PostgreSQL itself) + "F": OperationalError, # Configuration File Error + "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED) + "P": ProgrammingError, # PL/pgSQL Error + "X": InternalError, # Internal Error +} + + +# Error classes generated by tools/update_errors.py + +# fmt: off +# autogenerated: start + + +# Class 02 - No Data (this is also a warning class per the SQL standard) + +class NoData(DatabaseError, + code='02000', name='NO_DATA'): + pass + +class NoAdditionalDynamicResultSetsReturned(DatabaseError, + code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'): + pass + + +# Class 03 - SQL Statement Not Yet Complete + +class SqlStatementNotYetComplete(DatabaseError, + code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'): + pass + + +# Class 08 - Connection Exception + +class ConnectionException(OperationalError, + code='08000', name='CONNECTION_EXCEPTION'): + pass + +class SqlclientUnableToEstablishSqlconnection(OperationalError, + code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'): + pass + +class ConnectionDoesNotExist(OperationalError, + code='08003', name='CONNECTION_DOES_NOT_EXIST'): + pass + +class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError, + code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'): + pass + +class ConnectionFailure(OperationalError, + code='08006', name='CONNECTION_FAILURE'): + pass + +class TransactionResolutionUnknown(OperationalError, + code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'): + pass + +class ProtocolViolation(OperationalError, + code='08P01', name='PROTOCOL_VIOLATION'): + pass + + +# Class 09 - Triggered Action Exception + +class TriggeredActionException(DatabaseError, + code='09000', name='TRIGGERED_ACTION_EXCEPTION'): + pass + + +# Class 0A - Feature Not Supported + +class FeatureNotSupported(NotSupportedError, + code='0A000', name='FEATURE_NOT_SUPPORTED'): + pass + + +# Class 0B - Invalid Transaction Initiation + +class InvalidTransactionInitiation(DatabaseError, + code='0B000', name='INVALID_TRANSACTION_INITIATION'): + pass + + +# Class 0F - Locator Exception + +class LocatorException(DatabaseError, + code='0F000', name='LOCATOR_EXCEPTION'): + pass + +class InvalidLocatorSpecification(DatabaseError, + code='0F001', name='INVALID_LOCATOR_SPECIFICATION'): + pass + + +# Class 0L - Invalid Grantor + +class InvalidGrantor(DatabaseError, + code='0L000', name='INVALID_GRANTOR'): + pass + +class InvalidGrantOperation(DatabaseError, + code='0LP01', name='INVALID_GRANT_OPERATION'): + pass + + +# Class 0P - Invalid Role Specification + +class InvalidRoleSpecification(DatabaseError, + code='0P000', name='INVALID_ROLE_SPECIFICATION'): + pass + + +# Class 0Z - Diagnostics Exception + +class DiagnosticsException(DatabaseError, + code='0Z000', name='DIAGNOSTICS_EXCEPTION'): + pass + +class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError, + code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'): + pass + + +# Class 20 - Case Not Found + +class CaseNotFound(ProgrammingError, + code='20000', name='CASE_NOT_FOUND'): + pass + + +# Class 21 - Cardinality Violation + +class CardinalityViolation(ProgrammingError, + code='21000', name='CARDINALITY_VIOLATION'): + pass + + +# Class 22 - Data Exception + +class DataException(DataError, + code='22000', name='DATA_EXCEPTION'): + pass + +class StringDataRightTruncation(DataError, + code='22001', name='STRING_DATA_RIGHT_TRUNCATION'): + pass + +class NullValueNoIndicatorParameter(DataError, + code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'): + pass + +class NumericValueOutOfRange(DataError, + code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'): + pass + +class NullValueNotAllowed(DataError, + code='22004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class ErrorInAssignment(DataError, + code='22005', name='ERROR_IN_ASSIGNMENT'): + pass + +class InvalidDatetimeFormat(DataError, + code='22007', name='INVALID_DATETIME_FORMAT'): + pass + +class DatetimeFieldOverflow(DataError, + code='22008', name='DATETIME_FIELD_OVERFLOW'): + pass + +class InvalidTimeZoneDisplacementValue(DataError, + code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'): + pass + +class EscapeCharacterConflict(DataError, + code='2200B', name='ESCAPE_CHARACTER_CONFLICT'): + pass + +class InvalidUseOfEscapeCharacter(DataError, + code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'): + pass + +class InvalidEscapeOctet(DataError, + code='2200D', name='INVALID_ESCAPE_OCTET'): + pass + +class ZeroLengthCharacterString(DataError, + code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'): + pass + +class MostSpecificTypeMismatch(DataError, + code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'): + pass + +class SequenceGeneratorLimitExceeded(DataError, + code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'): + pass + +class NotAnXmlDocument(DataError, + code='2200L', name='NOT_AN_XML_DOCUMENT'): + pass + +class InvalidXmlDocument(DataError, + code='2200M', name='INVALID_XML_DOCUMENT'): + pass + +class InvalidXmlContent(DataError, + code='2200N', name='INVALID_XML_CONTENT'): + pass + +class InvalidXmlComment(DataError, + code='2200S', name='INVALID_XML_COMMENT'): + pass + +class InvalidXmlProcessingInstruction(DataError, + code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'): + pass + +class InvalidIndicatorParameterValue(DataError, + code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'): + pass + +class SubstringError(DataError, + code='22011', name='SUBSTRING_ERROR'): + pass + +class DivisionByZero(DataError, + code='22012', name='DIVISION_BY_ZERO'): + pass + +class InvalidPrecedingOrFollowingSize(DataError, + code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'): + pass + +class InvalidArgumentForNtileFunction(DataError, + code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'): + pass + +class IntervalFieldOverflow(DataError, + code='22015', name='INTERVAL_FIELD_OVERFLOW'): + pass + +class InvalidArgumentForNthValueFunction(DataError, + code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'): + pass + +class InvalidCharacterValueForCast(DataError, + code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'): + pass + +class InvalidEscapeCharacter(DataError, + code='22019', name='INVALID_ESCAPE_CHARACTER'): + pass + +class InvalidRegularExpression(DataError, + code='2201B', name='INVALID_REGULAR_EXPRESSION'): + pass + +class InvalidArgumentForLogarithm(DataError, + code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'): + pass + +class InvalidArgumentForPowerFunction(DataError, + code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'): + pass + +class InvalidArgumentForWidthBucketFunction(DataError, + code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'): + pass + +class InvalidRowCountInLimitClause(DataError, + code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'): + pass + +class InvalidRowCountInResultOffsetClause(DataError, + code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'): + pass + +class CharacterNotInRepertoire(DataError, + code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'): + pass + +class IndicatorOverflow(DataError, + code='22022', name='INDICATOR_OVERFLOW'): + pass + +class InvalidParameterValue(DataError, + code='22023', name='INVALID_PARAMETER_VALUE'): + pass + +class UnterminatedCString(DataError, + code='22024', name='UNTERMINATED_C_STRING'): + pass + +class InvalidEscapeSequence(DataError, + code='22025', name='INVALID_ESCAPE_SEQUENCE'): + pass + +class StringDataLengthMismatch(DataError, + code='22026', name='STRING_DATA_LENGTH_MISMATCH'): + pass + +class TrimError(DataError, + code='22027', name='TRIM_ERROR'): + pass + +class ArraySubscriptError(DataError, + code='2202E', name='ARRAY_SUBSCRIPT_ERROR'): + pass + +class InvalidTablesampleRepeat(DataError, + code='2202G', name='INVALID_TABLESAMPLE_REPEAT'): + pass + +class InvalidTablesampleArgument(DataError, + code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'): + pass + +class DuplicateJsonObjectKeyValue(DataError, + code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'): + pass + +class InvalidArgumentForSqlJsonDatetimeFunction(DataError, + code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'): + pass + +class InvalidJsonText(DataError, + code='22032', name='INVALID_JSON_TEXT'): + pass + +class InvalidSqlJsonSubscript(DataError, + code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'): + pass + +class MoreThanOneSqlJsonItem(DataError, + code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'): + pass + +class NoSqlJsonItem(DataError, + code='22035', name='NO_SQL_JSON_ITEM'): + pass + +class NonNumericSqlJsonItem(DataError, + code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'): + pass + +class NonUniqueKeysInAJsonObject(DataError, + code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'): + pass + +class SingletonSqlJsonItemRequired(DataError, + code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'): + pass + +class SqlJsonArrayNotFound(DataError, + code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'): + pass + +class SqlJsonMemberNotFound(DataError, + code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'): + pass + +class SqlJsonNumberNotFound(DataError, + code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'): + pass + +class SqlJsonObjectNotFound(DataError, + code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'): + pass + +class TooManyJsonArrayElements(DataError, + code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'): + pass + +class TooManyJsonObjectMembers(DataError, + code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'): + pass + +class SqlJsonScalarRequired(DataError, + code='2203F', name='SQL_JSON_SCALAR_REQUIRED'): + pass + +class SqlJsonItemCannotBeCastToTargetType(DataError, + code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'): + pass + +class FloatingPointException(DataError, + code='22P01', name='FLOATING_POINT_EXCEPTION'): + pass + +class InvalidTextRepresentation(DataError, + code='22P02', name='INVALID_TEXT_REPRESENTATION'): + pass + +class InvalidBinaryRepresentation(DataError, + code='22P03', name='INVALID_BINARY_REPRESENTATION'): + pass + +class BadCopyFileFormat(DataError, + code='22P04', name='BAD_COPY_FILE_FORMAT'): + pass + +class UntranslatableCharacter(DataError, + code='22P05', name='UNTRANSLATABLE_CHARACTER'): + pass + +class NonstandardUseOfEscapeCharacter(DataError, + code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'): + pass + + +# Class 23 - Integrity Constraint Violation + +class IntegrityConstraintViolation(IntegrityError, + code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class RestrictViolation(IntegrityError, + code='23001', name='RESTRICT_VIOLATION'): + pass + +class NotNullViolation(IntegrityError, + code='23502', name='NOT_NULL_VIOLATION'): + pass + +class ForeignKeyViolation(IntegrityError, + code='23503', name='FOREIGN_KEY_VIOLATION'): + pass + +class UniqueViolation(IntegrityError, + code='23505', name='UNIQUE_VIOLATION'): + pass + +class CheckViolation(IntegrityError, + code='23514', name='CHECK_VIOLATION'): + pass + +class ExclusionViolation(IntegrityError, + code='23P01', name='EXCLUSION_VIOLATION'): + pass + + +# Class 24 - Invalid Cursor State + +class InvalidCursorState(InternalError, + code='24000', name='INVALID_CURSOR_STATE'): + pass + + +# Class 25 - Invalid Transaction State + +class InvalidTransactionState(InternalError, + code='25000', name='INVALID_TRANSACTION_STATE'): + pass + +class ActiveSqlTransaction(InternalError, + code='25001', name='ACTIVE_SQL_TRANSACTION'): + pass + +class BranchTransactionAlreadyActive(InternalError, + code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'): + pass + +class InappropriateAccessModeForBranchTransaction(InternalError, + code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'): + pass + +class InappropriateIsolationLevelForBranchTransaction(InternalError, + code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'): + pass + +class NoActiveSqlTransactionForBranchTransaction(InternalError, + code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'): + pass + +class ReadOnlySqlTransaction(InternalError, + code='25006', name='READ_ONLY_SQL_TRANSACTION'): + pass + +class SchemaAndDataStatementMixingNotSupported(InternalError, + code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'): + pass + +class HeldCursorRequiresSameIsolationLevel(InternalError, + code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'): + pass + +class NoActiveSqlTransaction(InternalError, + code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'): + pass + +class InFailedSqlTransaction(InternalError, + code='25P02', name='IN_FAILED_SQL_TRANSACTION'): + pass + +class IdleInTransactionSessionTimeout(InternalError, + code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'): + pass + + +# Class 26 - Invalid SQL Statement Name + +class InvalidSqlStatementName(ProgrammingError, + code='26000', name='INVALID_SQL_STATEMENT_NAME'): + pass + + +# Class 27 - Triggered Data Change Violation + +class TriggeredDataChangeViolation(OperationalError, + code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'): + pass + + +# Class 28 - Invalid Authorization Specification + +class InvalidAuthorizationSpecification(OperationalError, + code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'): + pass + +class InvalidPassword(OperationalError, + code='28P01', name='INVALID_PASSWORD'): + pass + + +# Class 2B - Dependent Privilege Descriptors Still Exist + +class DependentPrivilegeDescriptorsStillExist(InternalError, + code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'): + pass + +class DependentObjectsStillExist(InternalError, + code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'): + pass + + +# Class 2D - Invalid Transaction Termination + +class InvalidTransactionTermination(InternalError, + code='2D000', name='INVALID_TRANSACTION_TERMINATION'): + pass + + +# Class 2F - SQL Routine Exception + +class SqlRoutineException(OperationalError, + code='2F000', name='SQL_ROUTINE_EXCEPTION'): + pass + +class ModifyingSqlDataNotPermitted(OperationalError, + code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttempted(OperationalError, + code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermitted(OperationalError, + code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + +class FunctionExecutedNoReturnStatement(OperationalError, + code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'): + pass + + +# Class 34 - Invalid Cursor Name + +class InvalidCursorName(ProgrammingError, + code='34000', name='INVALID_CURSOR_NAME'): + pass + + +# Class 38 - External Routine Exception + +class ExternalRoutineException(OperationalError, + code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'): + pass + +class ContainingSqlNotPermitted(OperationalError, + code='38001', name='CONTAINING_SQL_NOT_PERMITTED'): + pass + +class ModifyingSqlDataNotPermittedExt(OperationalError, + code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttemptedExt(OperationalError, + code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermittedExt(OperationalError, + code='38004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + + +# Class 39 - External Routine Invocation Exception + +class ExternalRoutineInvocationException(OperationalError, + code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'): + pass + +class InvalidSqlstateReturned(OperationalError, + code='39001', name='INVALID_SQLSTATE_RETURNED'): + pass + +class NullValueNotAllowedExt(OperationalError, + code='39004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class TriggerProtocolViolated(OperationalError, + code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'): + pass + +class SrfProtocolViolated(OperationalError, + code='39P02', name='SRF_PROTOCOL_VIOLATED'): + pass + +class EventTriggerProtocolViolated(OperationalError, + code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'): + pass + + +# Class 3B - Savepoint Exception + +class SavepointException(OperationalError, + code='3B000', name='SAVEPOINT_EXCEPTION'): + pass + +class InvalidSavepointSpecification(OperationalError, + code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'): + pass + + +# Class 3D - Invalid Catalog Name + +class InvalidCatalogName(ProgrammingError, + code='3D000', name='INVALID_CATALOG_NAME'): + pass + + +# Class 3F - Invalid Schema Name + +class InvalidSchemaName(ProgrammingError, + code='3F000', name='INVALID_SCHEMA_NAME'): + pass + + +# Class 40 - Transaction Rollback + +class TransactionRollback(OperationalError, + code='40000', name='TRANSACTION_ROLLBACK'): + pass + +class SerializationFailure(OperationalError, + code='40001', name='SERIALIZATION_FAILURE'): + pass + +class TransactionIntegrityConstraintViolation(OperationalError, + code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class StatementCompletionUnknown(OperationalError, + code='40003', name='STATEMENT_COMPLETION_UNKNOWN'): + pass + +class DeadlockDetected(OperationalError, + code='40P01', name='DEADLOCK_DETECTED'): + pass + + +# Class 42 - Syntax Error or Access Rule Violation + +class SyntaxErrorOrAccessRuleViolation(ProgrammingError, + code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'): + pass + +class InsufficientPrivilege(ProgrammingError, + code='42501', name='INSUFFICIENT_PRIVILEGE'): + pass + +class SyntaxError(ProgrammingError, + code='42601', name='SYNTAX_ERROR'): + pass + +class InvalidName(ProgrammingError, + code='42602', name='INVALID_NAME'): + pass + +class InvalidColumnDefinition(ProgrammingError, + code='42611', name='INVALID_COLUMN_DEFINITION'): + pass + +class NameTooLong(ProgrammingError, + code='42622', name='NAME_TOO_LONG'): + pass + +class DuplicateColumn(ProgrammingError, + code='42701', name='DUPLICATE_COLUMN'): + pass + +class AmbiguousColumn(ProgrammingError, + code='42702', name='AMBIGUOUS_COLUMN'): + pass + +class UndefinedColumn(ProgrammingError, + code='42703', name='UNDEFINED_COLUMN'): + pass + +class UndefinedObject(ProgrammingError, + code='42704', name='UNDEFINED_OBJECT'): + pass + +class DuplicateObject(ProgrammingError, + code='42710', name='DUPLICATE_OBJECT'): + pass + +class DuplicateAlias(ProgrammingError, + code='42712', name='DUPLICATE_ALIAS'): + pass + +class DuplicateFunction(ProgrammingError, + code='42723', name='DUPLICATE_FUNCTION'): + pass + +class AmbiguousFunction(ProgrammingError, + code='42725', name='AMBIGUOUS_FUNCTION'): + pass + +class GroupingError(ProgrammingError, + code='42803', name='GROUPING_ERROR'): + pass + +class DatatypeMismatch(ProgrammingError, + code='42804', name='DATATYPE_MISMATCH'): + pass + +class WrongObjectType(ProgrammingError, + code='42809', name='WRONG_OBJECT_TYPE'): + pass + +class InvalidForeignKey(ProgrammingError, + code='42830', name='INVALID_FOREIGN_KEY'): + pass + +class CannotCoerce(ProgrammingError, + code='42846', name='CANNOT_COERCE'): + pass + +class UndefinedFunction(ProgrammingError, + code='42883', name='UNDEFINED_FUNCTION'): + pass + +class GeneratedAlways(ProgrammingError, + code='428C9', name='GENERATED_ALWAYS'): + pass + +class ReservedName(ProgrammingError, + code='42939', name='RESERVED_NAME'): + pass + +class UndefinedTable(ProgrammingError, + code='42P01', name='UNDEFINED_TABLE'): + pass + +class UndefinedParameter(ProgrammingError, + code='42P02', name='UNDEFINED_PARAMETER'): + pass + +class DuplicateCursor(ProgrammingError, + code='42P03', name='DUPLICATE_CURSOR'): + pass + +class DuplicateDatabase(ProgrammingError, + code='42P04', name='DUPLICATE_DATABASE'): + pass + +class DuplicatePreparedStatement(ProgrammingError, + code='42P05', name='DUPLICATE_PREPARED_STATEMENT'): + pass + +class DuplicateSchema(ProgrammingError, + code='42P06', name='DUPLICATE_SCHEMA'): + pass + +class DuplicateTable(ProgrammingError, + code='42P07', name='DUPLICATE_TABLE'): + pass + +class AmbiguousParameter(ProgrammingError, + code='42P08', name='AMBIGUOUS_PARAMETER'): + pass + +class AmbiguousAlias(ProgrammingError, + code='42P09', name='AMBIGUOUS_ALIAS'): + pass + +class InvalidColumnReference(ProgrammingError, + code='42P10', name='INVALID_COLUMN_REFERENCE'): + pass + +class InvalidCursorDefinition(ProgrammingError, + code='42P11', name='INVALID_CURSOR_DEFINITION'): + pass + +class InvalidDatabaseDefinition(ProgrammingError, + code='42P12', name='INVALID_DATABASE_DEFINITION'): + pass + +class InvalidFunctionDefinition(ProgrammingError, + code='42P13', name='INVALID_FUNCTION_DEFINITION'): + pass + +class InvalidPreparedStatementDefinition(ProgrammingError, + code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'): + pass + +class InvalidSchemaDefinition(ProgrammingError, + code='42P15', name='INVALID_SCHEMA_DEFINITION'): + pass + +class InvalidTableDefinition(ProgrammingError, + code='42P16', name='INVALID_TABLE_DEFINITION'): + pass + +class InvalidObjectDefinition(ProgrammingError, + code='42P17', name='INVALID_OBJECT_DEFINITION'): + pass + +class IndeterminateDatatype(ProgrammingError, + code='42P18', name='INDETERMINATE_DATATYPE'): + pass + +class InvalidRecursion(ProgrammingError, + code='42P19', name='INVALID_RECURSION'): + pass + +class WindowingError(ProgrammingError, + code='42P20', name='WINDOWING_ERROR'): + pass + +class CollationMismatch(ProgrammingError, + code='42P21', name='COLLATION_MISMATCH'): + pass + +class IndeterminateCollation(ProgrammingError, + code='42P22', name='INDETERMINATE_COLLATION'): + pass + + +# Class 44 - WITH CHECK OPTION Violation + +class WithCheckOptionViolation(ProgrammingError, + code='44000', name='WITH_CHECK_OPTION_VIOLATION'): + pass + + +# Class 53 - Insufficient Resources + +class InsufficientResources(OperationalError, + code='53000', name='INSUFFICIENT_RESOURCES'): + pass + +class DiskFull(OperationalError, + code='53100', name='DISK_FULL'): + pass + +class OutOfMemory(OperationalError, + code='53200', name='OUT_OF_MEMORY'): + pass + +class TooManyConnections(OperationalError, + code='53300', name='TOO_MANY_CONNECTIONS'): + pass + +class ConfigurationLimitExceeded(OperationalError, + code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'): + pass + + +# Class 54 - Program Limit Exceeded + +class ProgramLimitExceeded(OperationalError, + code='54000', name='PROGRAM_LIMIT_EXCEEDED'): + pass + +class StatementTooComplex(OperationalError, + code='54001', name='STATEMENT_TOO_COMPLEX'): + pass + +class TooManyColumns(OperationalError, + code='54011', name='TOO_MANY_COLUMNS'): + pass + +class TooManyArguments(OperationalError, + code='54023', name='TOO_MANY_ARGUMENTS'): + pass + + +# Class 55 - Object Not In Prerequisite State + +class ObjectNotInPrerequisiteState(OperationalError, + code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'): + pass + +class ObjectInUse(OperationalError, + code='55006', name='OBJECT_IN_USE'): + pass + +class CantChangeRuntimeParam(OperationalError, + code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'): + pass + +class LockNotAvailable(OperationalError, + code='55P03', name='LOCK_NOT_AVAILABLE'): + pass + +class UnsafeNewEnumValueUsage(OperationalError, + code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'): + pass + + +# Class 57 - Operator Intervention + +class OperatorIntervention(OperationalError, + code='57000', name='OPERATOR_INTERVENTION'): + pass + +class QueryCanceled(OperationalError, + code='57014', name='QUERY_CANCELED'): + pass + +class AdminShutdown(OperationalError, + code='57P01', name='ADMIN_SHUTDOWN'): + pass + +class CrashShutdown(OperationalError, + code='57P02', name='CRASH_SHUTDOWN'): + pass + +class CannotConnectNow(OperationalError, + code='57P03', name='CANNOT_CONNECT_NOW'): + pass + +class DatabaseDropped(OperationalError, + code='57P04', name='DATABASE_DROPPED'): + pass + +class IdleSessionTimeout(OperationalError, + code='57P05', name='IDLE_SESSION_TIMEOUT'): + pass + + +# Class 58 - System Error (errors external to PostgreSQL itself) + +class SystemError(OperationalError, + code='58000', name='SYSTEM_ERROR'): + pass + +class IoError(OperationalError, + code='58030', name='IO_ERROR'): + pass + +class UndefinedFile(OperationalError, + code='58P01', name='UNDEFINED_FILE'): + pass + +class DuplicateFile(OperationalError, + code='58P02', name='DUPLICATE_FILE'): + pass + + +# Class 72 - Snapshot Failure + +class SnapshotTooOld(DatabaseError, + code='72000', name='SNAPSHOT_TOO_OLD'): + pass + + +# Class F0 - Configuration File Error + +class ConfigFileError(OperationalError, + code='F0000', name='CONFIG_FILE_ERROR'): + pass + +class LockFileExists(OperationalError, + code='F0001', name='LOCK_FILE_EXISTS'): + pass + + +# Class HV - Foreign Data Wrapper Error (SQL/MED) + +class FdwError(OperationalError, + code='HV000', name='FDW_ERROR'): + pass + +class FdwOutOfMemory(OperationalError, + code='HV001', name='FDW_OUT_OF_MEMORY'): + pass + +class FdwDynamicParameterValueNeeded(OperationalError, + code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'): + pass + +class FdwInvalidDataType(OperationalError, + code='HV004', name='FDW_INVALID_DATA_TYPE'): + pass + +class FdwColumnNameNotFound(OperationalError, + code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'): + pass + +class FdwInvalidDataTypeDescriptors(OperationalError, + code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'): + pass + +class FdwInvalidColumnName(OperationalError, + code='HV007', name='FDW_INVALID_COLUMN_NAME'): + pass + +class FdwInvalidColumnNumber(OperationalError, + code='HV008', name='FDW_INVALID_COLUMN_NUMBER'): + pass + +class FdwInvalidUseOfNullPointer(OperationalError, + code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'): + pass + +class FdwInvalidStringFormat(OperationalError, + code='HV00A', name='FDW_INVALID_STRING_FORMAT'): + pass + +class FdwInvalidHandle(OperationalError, + code='HV00B', name='FDW_INVALID_HANDLE'): + pass + +class FdwInvalidOptionIndex(OperationalError, + code='HV00C', name='FDW_INVALID_OPTION_INDEX'): + pass + +class FdwInvalidOptionName(OperationalError, + code='HV00D', name='FDW_INVALID_OPTION_NAME'): + pass + +class FdwOptionNameNotFound(OperationalError, + code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'): + pass + +class FdwReplyHandle(OperationalError, + code='HV00K', name='FDW_REPLY_HANDLE'): + pass + +class FdwUnableToCreateExecution(OperationalError, + code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'): + pass + +class FdwUnableToCreateReply(OperationalError, + code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'): + pass + +class FdwUnableToEstablishConnection(OperationalError, + code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'): + pass + +class FdwNoSchemas(OperationalError, + code='HV00P', name='FDW_NO_SCHEMAS'): + pass + +class FdwSchemaNotFound(OperationalError, + code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'): + pass + +class FdwTableNotFound(OperationalError, + code='HV00R', name='FDW_TABLE_NOT_FOUND'): + pass + +class FdwFunctionSequenceError(OperationalError, + code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'): + pass + +class FdwTooManyHandles(OperationalError, + code='HV014', name='FDW_TOO_MANY_HANDLES'): + pass + +class FdwInconsistentDescriptorInformation(OperationalError, + code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'): + pass + +class FdwInvalidAttributeValue(OperationalError, + code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'): + pass + +class FdwInvalidStringLengthOrBufferLength(OperationalError, + code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'): + pass + +class FdwInvalidDescriptorFieldIdentifier(OperationalError, + code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'): + pass + + +# Class P0 - PL/pgSQL Error + +class PlpgsqlError(ProgrammingError, + code='P0000', name='PLPGSQL_ERROR'): + pass + +class RaiseException(ProgrammingError, + code='P0001', name='RAISE_EXCEPTION'): + pass + +class NoDataFound(ProgrammingError, + code='P0002', name='NO_DATA_FOUND'): + pass + +class TooManyRows(ProgrammingError, + code='P0003', name='TOO_MANY_ROWS'): + pass + +class AssertFailure(ProgrammingError, + code='P0004', name='ASSERT_FAILURE'): + pass + + +# Class XX - Internal Error + +class InternalError_(InternalError, + code='XX000', name='INTERNAL_ERROR'): + pass + +class DataCorrupted(InternalError, + code='XX001', name='DATA_CORRUPTED'): + pass + +class IndexCorrupted(InternalError, + code='XX002', name='INDEX_CORRUPTED'): + pass + + +# autogenerated: end +# fmt: on diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py new file mode 100644 index 0000000..584fe47 --- /dev/null +++ b/psycopg/psycopg/generators.py @@ -0,0 +1,320 @@ +""" +Generators implementing communication protocols with the libpq + +Certain operations (connection, querying) are an interleave of libpq calls and +waiting for the socket to be ready. This module contains the code to execute +the operations, yielding a polling state whenever there is to wait. The +functions in the `waiting` module are the ones who wait more or less +cooperatively for the socket to be ready and make these generators continue. + +All these generators yield pairs (fileno, `Wait`) whenever an operation would +block. The generator can be restarted sending the appropriate `Ready` state +when the file descriptor is ready. + +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import List, Optional, Union + +from . import pq +from . import errors as e +from .abc import Buffer, PipelineCommand, PQGen, PQGenConn +from .pq.abc import PGconn, PGresult +from .waiting import Wait, Ready +from ._compat import Deque +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding, conninfo_encoding + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +POLL_OK = pq.PollingStatus.OK +POLL_READING = pq.PollingStatus.READING +POLL_WRITING = pq.PollingStatus.WRITING +POLL_FAILED = pq.PollingStatus.FAILED + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + +logger = logging.getLogger(__name__) + + +def _connect(conninfo: str) -> PQGenConn[PGconn]: + """ + Generator to create a database connection without blocking. + + """ + conn = pq.PGconn.connect_start(conninfo.encode()) + while True: + if conn.status == BAD: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection is bad: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + + status = conn.connect_poll() + if status == POLL_OK: + break + elif status == POLL_READING: + yield conn.socket, WAIT_R + elif status == POLL_WRITING: + yield conn.socket, WAIT_W + elif status == POLL_FAILED: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection failed: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + else: + raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn) + + conn.nonblocking = 1 + return conn + + +def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator sending a query and returning results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Return the list of results returned by the database (whether success + or error). + """ + yield from _send(pgconn) + rv = yield from _fetch_many(pgconn) + return rv + + +def _send(pgconn: PGconn) -> PQGen[None]: + """ + Generator to send a query to the server without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + After this generator has finished you may want to cycle using `fetch()` + to retrieve the results available. + """ + while True: + f = pgconn.flush() + if f == 0: + break + + ready = yield WAIT_RW + if ready & READY_R: + # This call may read notifies: they will be saved in the + # PGconn buffer and passed to Python later, in `fetch()`. + pgconn.consume_input() + + +def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator retrieving results from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return the list of results returned by the database (whether success + or error). + """ + results: List[PGresult] = [] + while True: + res = yield from _fetch(pgconn) + if not res: + break + + results.append(res) + status = res.status + if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + # After entering copy mode the libpq will create a phony result + # for every request so let's break the endless loop. + break + + if status == PIPELINE_SYNC: + # PIPELINE_SYNC is not followed by a NULL, but we return it alone + # similarly to other result sets. + assert len(results) == 1, results + break + + return results + + +def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]: + """ + Generator retrieving a single result from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return a result from the database (whether success or error). + """ + if pgconn.is_busy(): + yield WAIT_R + while True: + pgconn.consume_input() + if not pgconn.is_busy(): + break + yield WAIT_R + + _consume_notifies(pgconn) + + return pgconn.get_result() + + +def _pipeline_communicate( + pgconn: PGconn, commands: Deque[PipelineCommand] +) -> PQGen[List[List[PGresult]]]: + """Generator to send queries from a connection in pipeline mode while also + receiving results. + + Return a list results, including single PIPELINE_SYNC elements. + """ + results = [] + + while True: + ready = yield WAIT_RW + + if ready & READY_R: + pgconn.consume_input() + _consume_notifies(pgconn) + + res: List[PGresult] = [] + while not pgconn.is_busy(): + r = pgconn.get_result() + if r is None: + if not res: + break + results.append(res) + res = [] + elif r.status == PIPELINE_SYNC: + assert not res + results.append([r]) + else: + res.append(r) + + if ready & READY_W: + pgconn.flush() + if not commands: + break + commands.popleft()() + + return results + + +def _consume_notifies(pgconn: PGconn) -> None: + # Consume notifies + while True: + n = pgconn.notifies() + if not n: + break + if pgconn.notify_handler: + pgconn.notify_handler(n) + + +def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: + yield WAIT_R + pgconn.consume_input() + + ns = [] + while True: + n = pgconn.notifies() + if n: + ns.append(n) + else: + break + + return ns + + +def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]: + while True: + nbytes, data = pgconn.get_copy_data(1) + if nbytes != 0: + break + + # would block + yield WAIT_R + pgconn.consume_input() + + if nbytes > 0: + # some data + return data + + # Retrieve the final result of copy + results = yield from _fetch_many(pgconn) + if len(results) > 1: + # TODO: too brutal? Copy worked. + raise e.ProgrammingError("you cannot mix COPY with other operations") + result = results[0] + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]: + # Retry enqueuing data until successful. + # + # WARNING! This can cause an infinite loop if the buffer is too large. (see + # ticket #255). We avoid it in the Copy object by splitting a large buffer + # into smaller ones. We prefer to do it there instead of here in order to + # do it upstream the queue decoupling the writer task from the producer one. + while pgconn.put_copy_data(buffer) == 0: + yield WAIT_W + + +def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: + # Retry enqueuing end copy message until successful + while pgconn.put_copy_end(error) == 0: + yield WAIT_W + + # Repeat until it the message is flushed to the server + while True: + yield WAIT_W + f = pgconn.flush() + if f == 0: + break + + # Retrieve the final result of copy + (result,) = yield from _fetch_many(pgconn) + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +# Override functions with fast versions if available +if _psycopg: + connect = _psycopg.connect + execute = _psycopg.execute + send = _psycopg.send + fetch_many = _psycopg.fetch_many + fetch = _psycopg.fetch + pipeline_communicate = _psycopg.pipeline_communicate + +else: + connect = _connect + execute = _execute + send = _send + fetch_many = _fetch_many + fetch = _fetch + pipeline_communicate = _pipeline_communicate diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py new file mode 100644 index 0000000..792a9c8 --- /dev/null +++ b/psycopg/psycopg/postgres.py @@ -0,0 +1,125 @@ +""" +Types configuration specific to PostgreSQL. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry +from .abc import AdaptContext +from ._adapters_map import AdaptersMap + +# Global objects with PostgreSQL builtins and globally registered user types. +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + +# Use tools/update_oids.py to update this data. +for t in [ + TypeInfo('"char"', 18, 1002), + # autogenerated: start + # Generated from PostgreSQL 15.0 + TypeInfo("aclitem", 1033, 1034), + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("box", 603, 1020, delimiter=";"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("cid", 29, 1012), + TypeInfo("cidr", 650, 651), + TypeInfo("circle", 718, 719), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("gtsvector", 3642, 3644), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007, regtype="integer"), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("json", 114, 199), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("jsonpath", 4072, 4073), + TypeInfo("line", 628, 629), + TypeInfo("lseg", 601, 1018), + TypeInfo("macaddr", 829, 1040), + TypeInfo("macaddr8", 774, 775), + TypeInfo("money", 790, 791), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("path", 602, 1019), + TypeInfo("pg_lsn", 3220, 3221), + TypeInfo("point", 600, 1017), + TypeInfo("polygon", 604, 1027), + TypeInfo("record", 2249, 2287), + TypeInfo("refcursor", 1790, 2201), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regcollation", 4191, 4192), + TypeInfo("regconfig", 3734, 3735), + TypeInfo("regdictionary", 3769, 3770), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regoper", 2203, 2208), + TypeInfo("regoperator", 2204, 2209), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("tid", 27, 1010), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("tsquery", 3615, 3645), + TypeInfo("tsvector", 3614, 3643), + TypeInfo("txid_snapshot", 2970, 2949), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + TypeInfo("xid", 28, 1011), + TypeInfo("xid8", 5069, 271), + TypeInfo("xml", 142, 143), + RangeInfo("daterange", 3912, 3913, subtype_oid=1082), + RangeInfo("int4range", 3904, 3905, subtype_oid=23), + RangeInfo("int8range", 3926, 3927, subtype_oid=20), + RangeInfo("numrange", 3906, 3907, subtype_oid=1700), + RangeInfo("tsrange", 3908, 3909, subtype_oid=1114), + RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184), + MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082), + MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23), + MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20), + MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700), + MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114), + MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184), + # autogenerated: end +]: + types.add(t) + + +# A few oids used a bit everywhere +INVALID_OID = 0 +TEXT_OID = types["text"].oid +TEXT_ARRAY_OID = types["text"].array_oid + + +def register_default_adapters(context: AdaptContext) -> None: + + from .types import array, bool, composite, datetime, enum, json, multirange + from .types import net, none, numeric, range, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + enum.register_default_adapters(context) + json.register_default_adapters(context) + multirange.register_default_adapters(context) + net.register_default_adapters(context) + none.register_default_adapters(context) + numeric.register_default_adapters(context) + range.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py new file mode 100644 index 0000000..d5180b1 --- /dev/null +++ b/psycopg/psycopg/pq/__init__.py @@ -0,0 +1,133 @@ +""" +psycopg libpq wrapper + +This package exposes the libpq functionalities as Python objects and functions. + +The real implementation (the binding to the C library) is +implementation-dependant but all the implementations share the same interface. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import logging +from typing import Callable, List, Type + +from . import abc +from .misc import ConninfoOption, PGnotify, PGresAttDesc +from .misc import error_message +from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace +from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus + +logger = logging.getLogger(__name__) + +__impl__: str +"""The currently loaded implementation of the `!psycopg.pq` package. + +Possible values include ``python``, ``c``, ``binary``. +""" + +__build_version__: int +"""The libpq version the C package was built with. + +A number in the same format of `~psycopg.ConnectionInfo.server_version` +representing the libpq used to build the speedup module (``c``, ``binary``) if +available. + +Certain features might not be available if the built version is too old. +""" + +version: Callable[[], int] +PGconn: Type[abc.PGconn] +PGresult: Type[abc.PGresult] +Conninfo: Type[abc.Conninfo] +Escaping: Type[abc.Escaping] +PGcancel: Type[abc.PGcancel] + + +def import_from_libpq() -> None: + """ + Import pq objects implementation from the best libpq wrapper available. + + If an implementation is requested try to import only it, otherwise + try to import the best implementation available. + """ + # import these names into the module on success as side effect + global __impl__, version, __build_version__ + global PGconn, PGresult, Conninfo, Escaping, PGcancel + + impl = os.environ.get("PSYCOPG_IMPL", "").lower() + module = None + attempts: List[str] = [] + + def handle_error(name: str, e: Exception) -> None: + if not impl: + msg = f"couldn't import psycopg '{name}' implementation: {e}" + logger.debug(msg) + attempts.append(msg) + else: + msg = f"couldn't import requested psycopg '{name}' implementation: {e}" + raise ImportError(msg) from e + + # The best implementation: fast but requires the system libpq installed + if not impl or impl == "c": + try: + from psycopg_c import pq as module # type: ignore + except Exception as e: + handle_error("c", e) + + # Second best implementation: fast and stand-alone + if not module and (not impl or impl == "binary"): + try: + from psycopg_binary import pq as module # type: ignore + except Exception as e: + handle_error("binary", e) + + # Pure Python implementation, slow and requires the system libpq installed. + if not module and (not impl or impl == "python"): + try: + from . import pq_ctypes as module # type: ignore[no-redef] + except Exception as e: + handle_error("python", e) + + if module: + __impl__ = module.__impl__ + version = module.version + PGconn = module.PGconn + PGresult = module.PGresult + Conninfo = module.Conninfo + Escaping = module.Escaping + PGcancel = module.PGcancel + __build_version__ = module.__build_version__ + elif impl: + raise ImportError(f"requested psycopg implementation '{impl}' unknown") + else: + sattempts = "\n".join(f"- {attempt}" for attempt in attempts) + raise ImportError( + f"""\ +no pq wrapper available. +Attempts made: +{sattempts}""" + ) + + +import_from_libpq() + +__all__ = ( + "ConnStatus", + "PipelineStatus", + "PollingStatus", + "TransactionStatus", + "ExecStatus", + "Ping", + "DiagnosticField", + "Format", + "Trace", + "PGconn", + "PGnotify", + "Conninfo", + "PGresAttDesc", + "error_message", + "ConninfoOption", + "version", +) diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py new file mode 100644 index 0000000..f35d09f --- /dev/null +++ b/psycopg/psycopg/pq/_debug.py @@ -0,0 +1,106 @@ +""" +libpq debugging tools + +These functionalities are exposed here for convenience, but are not part of +the public interface and are subject to change at any moment. + +Suggested usage:: + + import logging + import psycopg + from psycopg import pq + from psycopg.pq._debug import PGconnDebug + + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + + assert pq.__impl__ == "python" + pq.PGconn = PGconnDebug + + with psycopg.connect("") as conn: + conn.pgconn.trace(2) + conn.pgconn.set_trace_flags( + pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + ... + +""" + +# Copyright (C) 2022 The Psycopg Team + +import inspect +import logging +from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING +from functools import wraps + +from . import PGconn +from .misc import connection_summary + +if TYPE_CHECKING: + from . import abc + +Func = TypeVar("Func", bound=Callable[..., Any]) + +logger = logging.getLogger("psycopg.debug") + + +class PGconnDebug: + """Wrapper for a PQconn logging all its access.""" + + _Self = TypeVar("_Self", bound="PGconnDebug") + _pgconn: "abc.PGconn" + + def __init__(self, pgconn: "abc.PGconn"): + super().__setattr__("_pgconn", pgconn) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def __getattr__(self, attr: str) -> Any: + value = getattr(self._pgconn, attr) + if callable(value): + return debugging(value) + else: + logger.info("PGconn.%s -> %s", attr, value) + return value + + def __setattr__(self, attr: str, value: Any) -> None: + setattr(self._pgconn, attr, value) + logger.info("PGconn.%s <- %s", attr, value) + + @classmethod + def connect(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect)(conninfo)) + + @classmethod + def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect_start)(conninfo)) + + @classmethod + def ping(self, conninfo: bytes) -> int: + return debugging(PGconn.ping)(conninfo) + + +def debugging(f: Func) -> Func: + """Wrap a function in order to log its arguments and return value on call.""" + + @wraps(f) + def debugging_(*args: Any, **kwargs: Any) -> Any: + reprs = [] + for arg in args: + reprs.append(f"{arg!r}") + for (k, v) in kwargs.items(): + reprs.append(f"{k}={v!r}") + + logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs)) + rv = f(*args, **kwargs) + # Display the return value only if the function is declared to return + # something else than None. + ra = inspect.signature(f).return_annotation + if ra is not None or rv is not None: + logger.info(" <- %r", rv) + return rv + + return debugging_ # type: ignore diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py new file mode 100644 index 0000000..e0d4018 --- /dev/null +++ b/psycopg/psycopg/pq/_enums.py @@ -0,0 +1,249 @@ +""" +libpq enum definitions for psycopg +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, IntFlag, auto + + +class ConnStatus(IntEnum): + """ + Current status of the connection. + """ + + __module__ = "psycopg.pq" + + OK = 0 + """The connection is in a working state.""" + BAD = auto() + """The connection is closed.""" + + STARTED = auto() + MADE = auto() + AWAITING_RESPONSE = auto() + AUTH_OK = auto() + SETENV = auto() + SSL_STARTUP = auto() + NEEDED = auto() + CHECK_WRITABLE = auto() + CONSUME = auto() + GSS_STARTUP = auto() + CHECK_TARGET = auto() + CHECK_STANDBY = auto() + + +class PollingStatus(IntEnum): + """ + The status of the socket during a connection. + + If ``READING`` or ``WRITING`` you may select before polling again. + """ + + __module__ = "psycopg.pq" + + FAILED = 0 + """Connection attempt failed.""" + READING = auto() + """Will have to wait before reading new data.""" + WRITING = auto() + """Will have to wait before writing new data.""" + OK = auto() + """Connection completed.""" + + ACTIVE = auto() + + +class ExecStatus(IntEnum): + """ + The status of a command. + """ + + __module__ = "psycopg.pq" + + EMPTY_QUERY = 0 + """The string sent to the server was empty.""" + + COMMAND_OK = auto() + """Successful completion of a command returning no data.""" + + TUPLES_OK = auto() + """ + Successful completion of a command returning data (such as a SELECT or SHOW). + """ + + COPY_OUT = auto() + """Copy Out (from server) data transfer started.""" + + COPY_IN = auto() + """Copy In (to server) data transfer started.""" + + BAD_RESPONSE = auto() + """The server's response was not understood.""" + + NONFATAL_ERROR = auto() + """A nonfatal error (a notice or warning) occurred.""" + + FATAL_ERROR = auto() + """A fatal error occurred.""" + + COPY_BOTH = auto() + """ + Copy In/Out (to and from server) data transfer started. + + This feature is currently used only for streaming replication, so this + status should not occur in ordinary applications. + """ + + SINGLE_TUPLE = auto() + """ + The PGresult contains a single result tuple from the current command. + + This status occurs only when single-row mode has been selected for the + query. + """ + + PIPELINE_SYNC = auto() + """ + The PGresult represents a synchronization point in pipeline mode, + requested by PQpipelineSync. + + This status occurs only when pipeline mode has been selected. + """ + + PIPELINE_ABORTED = auto() + """ + The PGresult represents a pipeline that has received an error from the server. + + PQgetResult must be called repeatedly, and each time it will return this + status code until the end of the current pipeline, at which point it will + return PGRES_PIPELINE_SYNC and normal processing can resume. + """ + + +class TransactionStatus(IntEnum): + """ + The transaction status of a connection. + """ + + __module__ = "psycopg.pq" + + IDLE = 0 + """Connection ready, no transaction active.""" + + ACTIVE = auto() + """A command is in progress.""" + + INTRANS = auto() + """Connection idle in an open transaction.""" + + INERROR = auto() + """An error happened in the current transaction.""" + + UNKNOWN = auto() + """Unknown connection state, broken connection.""" + + +class Ping(IntEnum): + """Response from a ping attempt.""" + + __module__ = "psycopg.pq" + + OK = 0 + """ + The server is running and appears to be accepting connections. + """ + + REJECT = auto() + """ + The server is running but is in a state that disallows connections. + """ + + NO_RESPONSE = auto() + """ + The server could not be contacted. + """ + + NO_ATTEMPT = auto() + """ + No attempt was made to contact the server. + """ + + +class PipelineStatus(IntEnum): + """Pipeline mode status of the libpq connection.""" + + __module__ = "psycopg.pq" + + OFF = 0 + """ + The libpq connection is *not* in pipeline mode. + """ + ON = auto() + """ + The libpq connection is in pipeline mode. + """ + ABORTED = auto() + """ + The libpq connection is in pipeline mode and an error occurred while + processing the current pipeline. The aborted flag is cleared when + PQgetResult returns a result of type PGRES_PIPELINE_SYNC. + """ + + +class DiagnosticField(IntEnum): + """ + Fields in an error report. + """ + + __module__ = "psycopg.pq" + + # from postgres_ext.h + SEVERITY = ord("S") + SEVERITY_NONLOCALIZED = ord("V") + SQLSTATE = ord("C") + MESSAGE_PRIMARY = ord("M") + MESSAGE_DETAIL = ord("D") + MESSAGE_HINT = ord("H") + STATEMENT_POSITION = ord("P") + INTERNAL_POSITION = ord("p") + INTERNAL_QUERY = ord("q") + CONTEXT = ord("W") + SCHEMA_NAME = ord("s") + TABLE_NAME = ord("t") + COLUMN_NAME = ord("c") + DATATYPE_NAME = ord("d") + CONSTRAINT_NAME = ord("n") + SOURCE_FILE = ord("F") + SOURCE_LINE = ord("L") + SOURCE_FUNCTION = ord("R") + + +class Format(IntEnum): + """ + Enum representing the format of a query argument or return value. + + These values are only the ones managed by the libpq. `~psycopg` may also + support automatically-chosen values: see `psycopg.adapt.PyFormat`. + """ + + __module__ = "psycopg.pq" + + TEXT = 0 + """Text parameter.""" + BINARY = 1 + """Binary parameter.""" + + +class Trace(IntFlag): + """ + Enum to control tracing of the client/server communication. + """ + + __module__ = "psycopg.pq" + + SUPPRESS_TIMESTAMPS = 1 + """Do not include timestamps in messages.""" + + REGRESS_MODE = 2 + """Redact some fields, e.g. OIDs, from messages.""" diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py new file mode 100644 index 0000000..9ca1d12 --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.py @@ -0,0 +1,804 @@ +""" +libpq access using ctypes +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import ctypes +import ctypes.util +from ctypes import Structure, CFUNCTYPE, POINTER +from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p +from typing import List, Optional, Tuple + +from .misc import find_libpq_full_path +from ..errors import NotSupportedError + +libname = find_libpq_full_path() +if not libname: + raise ImportError("libpq library not found") + +pq = ctypes.cdll.LoadLibrary(libname) + + +class FILE(Structure): + pass + + +FILE_ptr = POINTER(FILE) + +if sys.platform == "linux": + libcname = ctypes.util.find_library("c") + assert libcname + libc = ctypes.cdll.LoadLibrary(libcname) + + fdopen = libc.fdopen + fdopen.argtypes = (c_int, c_char_p) + fdopen.restype = FILE_ptr + + +# Get the libpq version to define what functions are available. + +PQlibVersion = pq.PQlibVersion +PQlibVersion.argtypes = [] +PQlibVersion.restype = c_int + +libpq_version = PQlibVersion() + + +# libpq data types + + +Oid = c_uint + + +class PGconn_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresult_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PQconninfoOption_struct(Structure): + _fields_ = [ + ("keyword", c_char_p), + ("envvar", c_char_p), + ("compiled", c_char_p), + ("val", c_char_p), + ("label", c_char_p), + ("dispchar", c_char_p), + ("dispsize", c_int), + ] + + +class PGnotify_struct(Structure): + _fields_ = [ + ("relname", c_char_p), + ("be_pid", c_int), + ("extra", c_char_p), + ] + + +class PGcancel_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresAttDesc_struct(Structure): + _fields_ = [ + ("name", c_char_p), + ("tableid", Oid), + ("columnid", c_int), + ("format", c_int), + ("typid", Oid), + ("typlen", c_int), + ("atttypmod", c_int), + ] + + +PGconn_ptr = POINTER(PGconn_struct) +PGresult_ptr = POINTER(PGresult_struct) +PQconninfoOption_ptr = POINTER(PQconninfoOption_struct) +PGnotify_ptr = POINTER(PGnotify_struct) +PGcancel_ptr = POINTER(PGcancel_struct) +PGresAttDesc_ptr = POINTER(PGresAttDesc_struct) + + +# Function definitions as explained in PostgreSQL 12 documentation + +# 33.1. Database Connection Control Functions + +# PQconnectdbParams: doesn't seem useful, won't wrap for now + +PQconnectdb = pq.PQconnectdb +PQconnectdb.argtypes = [c_char_p] +PQconnectdb.restype = PGconn_ptr + +# PQsetdbLogin: not useful +# PQsetdb: not useful + +# PQconnectStartParams: not useful + +PQconnectStart = pq.PQconnectStart +PQconnectStart.argtypes = [c_char_p] +PQconnectStart.restype = PGconn_ptr + +PQconnectPoll = pq.PQconnectPoll +PQconnectPoll.argtypes = [PGconn_ptr] +PQconnectPoll.restype = c_int + +PQconndefaults = pq.PQconndefaults +PQconndefaults.argtypes = [] +PQconndefaults.restype = PQconninfoOption_ptr + +PQconninfoFree = pq.PQconninfoFree +PQconninfoFree.argtypes = [PQconninfoOption_ptr] +PQconninfoFree.restype = None + +PQconninfo = pq.PQconninfo +PQconninfo.argtypes = [PGconn_ptr] +PQconninfo.restype = PQconninfoOption_ptr + +PQconninfoParse = pq.PQconninfoParse +PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)] +PQconninfoParse.restype = PQconninfoOption_ptr + +PQfinish = pq.PQfinish +PQfinish.argtypes = [PGconn_ptr] +PQfinish.restype = None + +PQreset = pq.PQreset +PQreset.argtypes = [PGconn_ptr] +PQreset.restype = None + +PQresetStart = pq.PQresetStart +PQresetStart.argtypes = [PGconn_ptr] +PQresetStart.restype = c_int + +PQresetPoll = pq.PQresetPoll +PQresetPoll.argtypes = [PGconn_ptr] +PQresetPoll.restype = c_int + +PQping = pq.PQping +PQping.argtypes = [c_char_p] +PQping.restype = c_int + + +# 33.2. Connection Status Functions + +PQdb = pq.PQdb +PQdb.argtypes = [PGconn_ptr] +PQdb.restype = c_char_p + +PQuser = pq.PQuser +PQuser.argtypes = [PGconn_ptr] +PQuser.restype = c_char_p + +PQpass = pq.PQpass +PQpass.argtypes = [PGconn_ptr] +PQpass.restype = c_char_p + +PQhost = pq.PQhost +PQhost.argtypes = [PGconn_ptr] +PQhost.restype = c_char_p + +_PQhostaddr = None + +if libpq_version >= 120000: + _PQhostaddr = pq.PQhostaddr + _PQhostaddr.argtypes = [PGconn_ptr] + _PQhostaddr.restype = c_char_p + + +def PQhostaddr(pgconn: PGconn_struct) -> bytes: + if not _PQhostaddr: + raise NotSupportedError( + "PQhostaddr requires libpq from PostgreSQL 12," + f" {libpq_version} available instead" + ) + + return _PQhostaddr(pgconn) + + +PQport = pq.PQport +PQport.argtypes = [PGconn_ptr] +PQport.restype = c_char_p + +PQtty = pq.PQtty +PQtty.argtypes = [PGconn_ptr] +PQtty.restype = c_char_p + +PQoptions = pq.PQoptions +PQoptions.argtypes = [PGconn_ptr] +PQoptions.restype = c_char_p + +PQstatus = pq.PQstatus +PQstatus.argtypes = [PGconn_ptr] +PQstatus.restype = c_int + +PQtransactionStatus = pq.PQtransactionStatus +PQtransactionStatus.argtypes = [PGconn_ptr] +PQtransactionStatus.restype = c_int + +PQparameterStatus = pq.PQparameterStatus +PQparameterStatus.argtypes = [PGconn_ptr, c_char_p] +PQparameterStatus.restype = c_char_p + +PQprotocolVersion = pq.PQprotocolVersion +PQprotocolVersion.argtypes = [PGconn_ptr] +PQprotocolVersion.restype = c_int + +PQserverVersion = pq.PQserverVersion +PQserverVersion.argtypes = [PGconn_ptr] +PQserverVersion.restype = c_int + +PQerrorMessage = pq.PQerrorMessage +PQerrorMessage.argtypes = [PGconn_ptr] +PQerrorMessage.restype = c_char_p + +PQsocket = pq.PQsocket +PQsocket.argtypes = [PGconn_ptr] +PQsocket.restype = c_int + +PQbackendPID = pq.PQbackendPID +PQbackendPID.argtypes = [PGconn_ptr] +PQbackendPID.restype = c_int + +PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword +PQconnectionNeedsPassword.argtypes = [PGconn_ptr] +PQconnectionNeedsPassword.restype = c_int + +PQconnectionUsedPassword = pq.PQconnectionUsedPassword +PQconnectionUsedPassword.argtypes = [PGconn_ptr] +PQconnectionUsedPassword.restype = c_int + +PQsslInUse = pq.PQsslInUse +PQsslInUse.argtypes = [PGconn_ptr] +PQsslInUse.restype = c_int + +# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl + + +# 33.3. Command Execution Functions + +PQexec = pq.PQexec +PQexec.argtypes = [PGconn_ptr, c_char_p] +PQexec.restype = PGresult_ptr + +PQexecParams = pq.PQexecParams +PQexecParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecParams.restype = PGresult_ptr + +PQprepare = pq.PQprepare +PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQprepare.restype = PGresult_ptr + +PQexecPrepared = pq.PQexecPrepared +PQexecPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecPrepared.restype = PGresult_ptr + +PQdescribePrepared = pq.PQdescribePrepared +PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQdescribePrepared.restype = PGresult_ptr + +PQdescribePortal = pq.PQdescribePortal +PQdescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQdescribePortal.restype = PGresult_ptr + +PQresultStatus = pq.PQresultStatus +PQresultStatus.argtypes = [PGresult_ptr] +PQresultStatus.restype = c_int + +# PQresStatus: not needed, we have pretty enums + +PQresultErrorMessage = pq.PQresultErrorMessage +PQresultErrorMessage.argtypes = [PGresult_ptr] +PQresultErrorMessage.restype = c_char_p + +# TODO: PQresultVerboseErrorMessage + +PQresultErrorField = pq.PQresultErrorField +PQresultErrorField.argtypes = [PGresult_ptr, c_int] +PQresultErrorField.restype = c_char_p + +PQclear = pq.PQclear +PQclear.argtypes = [PGresult_ptr] +PQclear.restype = None + + +# 33.3.2. Retrieving Query Result Information + +PQntuples = pq.PQntuples +PQntuples.argtypes = [PGresult_ptr] +PQntuples.restype = c_int + +PQnfields = pq.PQnfields +PQnfields.argtypes = [PGresult_ptr] +PQnfields.restype = c_int + +PQfname = pq.PQfname +PQfname.argtypes = [PGresult_ptr, c_int] +PQfname.restype = c_char_p + +# PQfnumber: useless and hard to use + +PQftable = pq.PQftable +PQftable.argtypes = [PGresult_ptr, c_int] +PQftable.restype = Oid + +PQftablecol = pq.PQftablecol +PQftablecol.argtypes = [PGresult_ptr, c_int] +PQftablecol.restype = c_int + +PQfformat = pq.PQfformat +PQfformat.argtypes = [PGresult_ptr, c_int] +PQfformat.restype = c_int + +PQftype = pq.PQftype +PQftype.argtypes = [PGresult_ptr, c_int] +PQftype.restype = Oid + +PQfmod = pq.PQfmod +PQfmod.argtypes = [PGresult_ptr, c_int] +PQfmod.restype = c_int + +PQfsize = pq.PQfsize +PQfsize.argtypes = [PGresult_ptr, c_int] +PQfsize.restype = c_int + +PQbinaryTuples = pq.PQbinaryTuples +PQbinaryTuples.argtypes = [PGresult_ptr] +PQbinaryTuples.restype = c_int + +PQgetvalue = pq.PQgetvalue +PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int] +PQgetvalue.restype = POINTER(c_char) # not a null-terminated string + +PQgetisnull = pq.PQgetisnull +PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int] +PQgetisnull.restype = c_int + +PQgetlength = pq.PQgetlength +PQgetlength.argtypes = [PGresult_ptr, c_int, c_int] +PQgetlength.restype = c_int + +PQnparams = pq.PQnparams +PQnparams.argtypes = [PGresult_ptr] +PQnparams.restype = c_int + +PQparamtype = pq.PQparamtype +PQparamtype.argtypes = [PGresult_ptr, c_int] +PQparamtype.restype = Oid + +# PQprint: pretty useless + +# 33.3.3. Retrieving Other Result Information + +PQcmdStatus = pq.PQcmdStatus +PQcmdStatus.argtypes = [PGresult_ptr] +PQcmdStatus.restype = c_char_p + +PQcmdTuples = pq.PQcmdTuples +PQcmdTuples.argtypes = [PGresult_ptr] +PQcmdTuples.restype = c_char_p + +PQoidValue = pq.PQoidValue +PQoidValue.argtypes = [PGresult_ptr] +PQoidValue.restype = Oid + + +# 33.3.4. Escaping Strings for Inclusion in SQL Commands + +PQescapeLiteral = pq.PQescapeLiteral +PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeLiteral.restype = POINTER(c_char) + +PQescapeIdentifier = pq.PQescapeIdentifier +PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeIdentifier.restype = POINTER(c_char) + +PQescapeStringConn = pq.PQescapeStringConn +# TODO: raises "wrong type" error +# PQescapeStringConn.argtypes = [ +# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int) +# ] +PQescapeStringConn.restype = c_size_t + +PQescapeString = pq.PQescapeString +# TODO: raises "wrong type" error +# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t] +PQescapeString.restype = c_size_t + +PQescapeByteaConn = pq.PQescapeByteaConn +PQescapeByteaConn.argtypes = [ + PGconn_ptr, + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeByteaConn.restype = POINTER(c_ubyte) + +PQescapeBytea = pq.PQescapeBytea +PQescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeBytea.restype = POINTER(c_ubyte) + + +PQunescapeBytea = pq.PQunescapeBytea +PQunescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + POINTER(c_size_t), +] +PQunescapeBytea.restype = POINTER(c_ubyte) + + +# 33.4. Asynchronous Command Processing + +PQsendQuery = pq.PQsendQuery +PQsendQuery.argtypes = [PGconn_ptr, c_char_p] +PQsendQuery.restype = c_int + +PQsendQueryParams = pq.PQsendQueryParams +PQsendQueryParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryParams.restype = c_int + +PQsendPrepare = pq.PQsendPrepare +PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQsendPrepare.restype = c_int + +PQsendQueryPrepared = pq.PQsendQueryPrepared +PQsendQueryPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryPrepared.restype = c_int + +PQsendDescribePrepared = pq.PQsendDescribePrepared +PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePrepared.restype = c_int + +PQsendDescribePortal = pq.PQsendDescribePortal +PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePortal.restype = c_int + +PQgetResult = pq.PQgetResult +PQgetResult.argtypes = [PGconn_ptr] +PQgetResult.restype = PGresult_ptr + +PQconsumeInput = pq.PQconsumeInput +PQconsumeInput.argtypes = [PGconn_ptr] +PQconsumeInput.restype = c_int + +PQisBusy = pq.PQisBusy +PQisBusy.argtypes = [PGconn_ptr] +PQisBusy.restype = c_int + +PQsetnonblocking = pq.PQsetnonblocking +PQsetnonblocking.argtypes = [PGconn_ptr, c_int] +PQsetnonblocking.restype = c_int + +PQisnonblocking = pq.PQisnonblocking +PQisnonblocking.argtypes = [PGconn_ptr] +PQisnonblocking.restype = c_int + +PQflush = pq.PQflush +PQflush.argtypes = [PGconn_ptr] +PQflush.restype = c_int + + +# 33.5. Retrieving Query Results Row-by-Row +PQsetSingleRowMode = pq.PQsetSingleRowMode +PQsetSingleRowMode.argtypes = [PGconn_ptr] +PQsetSingleRowMode.restype = c_int + + +# 33.6. Canceling Queries in Progress + +PQgetCancel = pq.PQgetCancel +PQgetCancel.argtypes = [PGconn_ptr] +PQgetCancel.restype = PGcancel_ptr + +PQfreeCancel = pq.PQfreeCancel +PQfreeCancel.argtypes = [PGcancel_ptr] +PQfreeCancel.restype = None + +PQcancel = pq.PQcancel +# TODO: raises "wrong type" error +# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int] +PQcancel.restype = c_int + + +# 33.8. Asynchronous Notification + +PQnotifies = pq.PQnotifies +PQnotifies.argtypes = [PGconn_ptr] +PQnotifies.restype = PGnotify_ptr + + +# 33.9. Functions Associated with the COPY Command + +PQputCopyData = pq.PQputCopyData +PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int] +PQputCopyData.restype = c_int + +PQputCopyEnd = pq.PQputCopyEnd +PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p] +PQputCopyEnd.restype = c_int + +PQgetCopyData = pq.PQgetCopyData +PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int] +PQgetCopyData.restype = c_int + + +# 33.10. Control Functions + +PQtrace = pq.PQtrace +PQtrace.argtypes = [PGconn_ptr, FILE_ptr] +PQtrace.restype = None + +_PQsetTraceFlags = None + +if libpq_version >= 140000: + _PQsetTraceFlags = pq.PQsetTraceFlags + _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int] + _PQsetTraceFlags.restype = None + + +def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None: + if not _PQsetTraceFlags: + raise NotSupportedError( + "PQsetTraceFlags requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + + _PQsetTraceFlags(pgconn, flags) + + +PQuntrace = pq.PQuntrace +PQuntrace.argtypes = [PGconn_ptr] +PQuntrace.restype = None + +# 33.11. Miscellaneous Functions + +PQfreemem = pq.PQfreemem +PQfreemem.argtypes = [c_void_p] +PQfreemem.restype = None + +if libpq_version >= 100000: + _PQencryptPasswordConn = pq.PQencryptPasswordConn + _PQencryptPasswordConn.argtypes = [ + PGconn_ptr, + c_char_p, + c_char_p, + c_char_p, + ] + _PQencryptPasswordConn.restype = POINTER(c_char) + + +def PQencryptPasswordConn( + pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes +) -> Optional[bytes]: + if not _PQencryptPasswordConn: + raise NotSupportedError( + "PQencryptPasswordConn requires libpq from PostgreSQL 10," + f" {libpq_version} available instead" + ) + + return _PQencryptPasswordConn(pgconn, passwd, user, algorithm) + + +PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult +PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int] +PQmakeEmptyPGresult.restype = PGresult_ptr + +PQsetResultAttrs = pq.PQsetResultAttrs +PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr] +PQsetResultAttrs.restype = c_int + + +# 33.12. Notice Processing + +PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr) + +PQsetNoticeReceiver = pq.PQsetNoticeReceiver +PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p] +PQsetNoticeReceiver.restype = PQnoticeReceiver + +# 34.5 Pipeline Mode + +_PQpipelineStatus = None +_PQenterPipelineMode = None +_PQexitPipelineMode = None +_PQpipelineSync = None +_PQsendFlushRequest = None + +if libpq_version >= 140000: + _PQpipelineStatus = pq.PQpipelineStatus + _PQpipelineStatus.argtypes = [PGconn_ptr] + _PQpipelineStatus.restype = c_int + + _PQenterPipelineMode = pq.PQenterPipelineMode + _PQenterPipelineMode.argtypes = [PGconn_ptr] + _PQenterPipelineMode.restype = c_int + + _PQexitPipelineMode = pq.PQexitPipelineMode + _PQexitPipelineMode.argtypes = [PGconn_ptr] + _PQexitPipelineMode.restype = c_int + + _PQpipelineSync = pq.PQpipelineSync + _PQpipelineSync.argtypes = [PGconn_ptr] + _PQpipelineSync.restype = c_int + + _PQsendFlushRequest = pq.PQsendFlushRequest + _PQsendFlushRequest.argtypes = [PGconn_ptr] + _PQsendFlushRequest.restype = c_int + + +def PQpipelineStatus(pgconn: PGconn_struct) -> int: + if not _PQpipelineStatus: + raise NotSupportedError( + "PQpipelineStatus requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineStatus(pgconn) + + +def PQenterPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQenterPipelineMode: + raise NotSupportedError( + "PQenterPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQenterPipelineMode(pgconn) + + +def PQexitPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQexitPipelineMode: + raise NotSupportedError( + "PQexitPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQexitPipelineMode(pgconn) + + +def PQpipelineSync(pgconn: PGconn_struct) -> int: + if not _PQpipelineSync: + raise NotSupportedError( + "PQpipelineSync requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineSync(pgconn) + + +def PQsendFlushRequest(pgconn: PGconn_struct) -> int: + if not _PQsendFlushRequest: + raise NotSupportedError( + "PQsendFlushRequest requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQsendFlushRequest(pgconn) + + +# 33.18. SSL Support + +PQinitOpenSSL = pq.PQinitOpenSSL +PQinitOpenSSL.argtypes = [c_int, c_int] +PQinitOpenSSL.restype = None + + +def generate_stub() -> None: + import re + from ctypes import _CFuncPtr # type: ignore + + def type2str(fname, narg, t): + if t is None: + return "None" + elif t is c_void_p: + return "Any" + elif t is c_int or t is c_uint or t is c_size_t: + return "int" + elif t is c_char_p or t.__name__ == "LP_c_char": + if narg is not None: + return "bytes" + else: + return "Optional[bytes]" + + elif t.__name__ in ( + "LP_PGconn_struct", + "LP_PGresult_struct", + "LP_PGcancel_struct", + ): + if narg is not None: + return f"Optional[{t.__name__[3:]}]" + else: + return t.__name__[3:] + + elif t.__name__ in ("LP_PQconninfoOption_struct",): + return f"Sequence[{t.__name__[3:]}]" + + elif t.__name__ in ( + "LP_c_ubyte", + "LP_c_char_p", + "LP_c_int", + "LP_c_uint", + "LP_c_ulong", + "LP_FILE", + ): + return f"_Pointer[{t.__name__[3:]}]" + + else: + assert False, f"can't deal with {t} in {fname}" + + fn = __file__ + "i" + with open(fn) as f: + lines = f.read().splitlines() + + istart, iend = ( + i + for i, line in enumerate(lines) + if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) + ) + + known = { + line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ") + } + + signatures = [] + + for name, obj in globals().items(): + if name in known: + continue + if not isinstance(obj, _CFuncPtr): + continue + + params = [] + for i, t in enumerate(obj.argtypes): + params.append(f"arg{i + 1}: {type2str(name, i, t)}") + + resname = type2str(name, None, obj.restype) + + signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...") + + lines[istart + 1 : iend] = signatures + + with open(fn, "w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +if __name__ == "__main__": + generate_stub() diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi new file mode 100644 index 0000000..5d2ee3f --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.pyi @@ -0,0 +1,216 @@ +""" +types stub for ctypes functions +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Optional, Sequence +from ctypes import Array, pointer, _Pointer +from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong + +class FILE: ... + +def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var] + +Oid = c_uint + +class PGconn_struct: ... +class PGresult_struct: ... +class PGcancel_struct: ... + +class PQconninfoOption_struct: + keyword: bytes + envvar: bytes + compiled: bytes + val: bytes + label: bytes + dispchar: bytes + dispsize: int + +class PGnotify_struct: + be_pid: int + relname: bytes + extra: bytes + +class PGresAttDesc_struct: + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + +def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQexecPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> PGresult_struct: ... +def PQprepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> PGresult_struct: ... +def PQgetvalue( + arg1: Optional[PGresult_struct], arg2: int, arg3: int +) -> _Pointer[c_char]: ... +def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQescapeStringConn( + arg1: Optional[PGconn_struct], + arg2: c_char_p, + arg3: bytes, + arg4: int, + arg5: _Pointer[c_int], +) -> int: ... +def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ... +def PQsendPrepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> int: ... +def PQsendQueryPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> int: ... +def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ... +def PQsetNoticeReceiver( + arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any +) -> Callable[[Any], PGresult_struct]: ... + +# TODO: Ignoring type as getting an error on mypy/ctypes: +# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be +# a subtype of "ctypes._CData" +def PQnotifies( + arg1: Optional[PGconn_struct], +) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore +def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ... + +# Arg 2 is a _Pointer, reported as _CArgObject by mypy +def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ... +def PQsetResultAttrs( + arg1: Optional[PGresult_struct], + arg2: int, + arg3: Array[PGresAttDesc_struct], # type: ignore +) -> int: ... +def PQtrace( + arg1: Optional[PGconn_struct], + arg2: _Pointer[FILE], # type: ignore[type-var] +) -> None: ... +def PQencryptPasswordConn( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: Optional[bytes], +) -> bytes: ... +def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ... +def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ... +def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ... + +# fmt: off +# autogenerated: start +def PQlibVersion() -> int: ... +def PQconnectdb(arg1: bytes) -> PGconn_struct: ... +def PQconnectStart(arg1: bytes) -> PGconn_struct: ... +def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ... +def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ... +def PQfinish(arg1: Optional[PGconn_struct]) -> None: ... +def PQreset(arg1: Optional[PGconn_struct]) -> None: ... +def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ... +def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQping(arg1: bytes) -> int: ... +def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQstatus(arg1: Optional[PGconn_struct]) -> int: ... +def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ... +def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ... +def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQsocket(arg1: Optional[PGconn_struct]) -> int: ... +def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ... +def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ... +def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ... +def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQclear(arg1: Optional[PGresult_struct]) -> None: ... +def PQntuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQnfields(arg1: Optional[PGresult_struct]) -> int: ... +def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... +def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... +def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... +def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ... +def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ... +def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ... +def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ... +def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ... +def PQflush(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ... +def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ... +def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ... +def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ... +def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ... +def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ... +def PQfreemem(arg1: Any) -> None: ... +def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ... +def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ... +def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ... +def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ... +def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ... +def PQinitOpenSSL(arg1: int, arg2: int) -> None: ... +# autogenerated: end +# fmt: on + +# vim: set syntax=python: diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py new file mode 100644 index 0000000..9c45f64 --- /dev/null +++ b/psycopg/psycopg/pq/abc.py @@ -0,0 +1,385 @@ +""" +Protocol objects to represent objects exposed by different pq implementations. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from ._enums import Format, Trace +from .._compat import Protocol + +if TYPE_CHECKING: + from .misc import PGnotify, ConninfoOption, PGresAttDesc + +# An object implementing the buffer protocol (ish) +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + + +class PGconn(Protocol): + + notice_handler: Optional[Callable[["PGresult"], None]] + notify_handler: Optional[Callable[["PGnotify"], None]] + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + ... + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + ... + + def connect_poll(self) -> int: + ... + + def finish(self) -> None: + ... + + @property + def info(self) -> List["ConninfoOption"]: + ... + + def reset(self) -> None: + ... + + def reset_start(self) -> None: + ... + + def reset_poll(self) -> int: + ... + + @classmethod + def ping(self, conninfo: bytes) -> int: + ... + + @property + def db(self) -> bytes: + ... + + @property + def user(self) -> bytes: + ... + + @property + def password(self) -> bytes: + ... + + @property + def host(self) -> bytes: + ... + + @property + def hostaddr(self) -> bytes: + ... + + @property + def port(self) -> bytes: + ... + + @property + def tty(self) -> bytes: + ... + + @property + def options(self) -> bytes: + ... + + @property + def status(self) -> int: + ... + + @property + def transaction_status(self) -> int: + ... + + def parameter_status(self, name: bytes) -> Optional[bytes]: + ... + + @property + def error_message(self) -> bytes: + ... + + @property + def server_version(self) -> int: + ... + + @property + def socket(self) -> int: + ... + + @property + def backend_pid(self) -> int: + ... + + @property + def needs_password(self) -> bool: + ... + + @property + def used_password(self) -> bool: + ... + + @property + def ssl_in_use(self) -> bool: + ... + + def exec_(self, command: bytes) -> "PGresult": + ... + + def send_query(self, command: bytes) -> None: + ... + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + ... + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + ... + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + ... + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Buffer]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + ... + + def describe_prepared(self, name: bytes) -> "PGresult": + ... + + def send_describe_prepared(self, name: bytes) -> None: + ... + + def describe_portal(self, name: bytes) -> "PGresult": + ... + + def send_describe_portal(self, name: bytes) -> None: + ... + + def get_result(self) -> Optional["PGresult"]: + ... + + def consume_input(self) -> None: + ... + + def is_busy(self) -> int: + ... + + @property + def nonblocking(self) -> int: + ... + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + ... + + def flush(self) -> int: + ... + + def set_single_row_mode(self) -> None: + ... + + def get_cancel(self) -> "PGcancel": + ... + + def notifies(self) -> Optional["PGnotify"]: + ... + + def put_copy_data(self, buffer: Buffer) -> int: + ... + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + ... + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + ... + + def trace(self, fileno: int) -> None: + ... + + def set_trace_flags(self, flags: Trace) -> None: + ... + + def untrace(self) -> None: + ... + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + ... + + def make_empty_result(self, exec_status: int) -> "PGresult": + ... + + @property + def pipeline_status(self) -> int: + ... + + def enter_pipeline_mode(self) -> None: + ... + + def exit_pipeline_mode(self) -> None: + ... + + def pipeline_sync(self) -> None: + ... + + def send_flush_request(self) -> None: + ... + + +class PGresult(Protocol): + def clear(self) -> None: + ... + + @property + def status(self) -> int: + ... + + @property + def error_message(self) -> bytes: + ... + + def error_field(self, fieldcode: int) -> Optional[bytes]: + ... + + @property + def ntuples(self) -> int: + ... + + @property + def nfields(self) -> int: + ... + + def fname(self, column_number: int) -> Optional[bytes]: + ... + + def ftable(self, column_number: int) -> int: + ... + + def ftablecol(self, column_number: int) -> int: + ... + + def fformat(self, column_number: int) -> int: + ... + + def ftype(self, column_number: int) -> int: + ... + + def fmod(self, column_number: int) -> int: + ... + + def fsize(self, column_number: int) -> int: + ... + + @property + def binary_tuples(self) -> int: + ... + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + ... + + @property + def nparams(self) -> int: + ... + + def param_type(self, param_number: int) -> int: + ... + + @property + def command_status(self) -> Optional[bytes]: + ... + + @property + def command_tuples(self) -> Optional[int]: + ... + + @property + def oid_value(self) -> int: + ... + + def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: + ... + + +class PGcancel(Protocol): + def free(self) -> None: + ... + + def cancel(self) -> None: + ... + + +class Conninfo(Protocol): + @classmethod + def get_defaults(cls) -> List["ConninfoOption"]: + ... + + @classmethod + def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: + ... + + @classmethod + def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: + ... + + +class Escaping(Protocol): + def __init__(self, conn: Optional[PGconn] = None): + ... + + def escape_literal(self, data: Buffer) -> bytes: + ... + + def escape_identifier(self, data: Buffer) -> bytes: + ... + + def escape_string(self, data: Buffer) -> bytes: + ... + + def escape_bytea(self, data: Buffer) -> bytes: + ... + + def unescape_bytea(self, data: Buffer) -> bytes: + ... diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py new file mode 100644 index 0000000..3a43133 --- /dev/null +++ b/psycopg/psycopg/pq/misc.py @@ -0,0 +1,146 @@ +""" +Various functionalities to make easier to work with the libpq. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import sys +import logging +import ctypes.util +from typing import cast, NamedTuple, Optional, Union + +from .abc import PGconn, PGresult +from ._enums import ConnStatus, TransactionStatus, PipelineStatus +from .._compat import cache +from .._encodings import pgconn_encoding + +logger = logging.getLogger("psycopg.pq") + +OK = ConnStatus.OK + + +class PGnotify(NamedTuple): + relname: bytes + be_pid: int + extra: bytes + + +class ConninfoOption(NamedTuple): + keyword: bytes + envvar: Optional[bytes] + compiled: Optional[bytes] + val: Optional[bytes] + label: bytes + dispchar: bytes + dispsize: int + + +class PGresAttDesc(NamedTuple): + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + + +@cache +def find_libpq_full_path() -> Optional[str]: + if sys.platform == "win32": + libname = ctypes.util.find_library("libpq.dll") + + elif sys.platform == "darwin": + libname = ctypes.util.find_library("libpq.dylib") + # (hopefully) temporary hack: libpq not in a standard place + # https://github.com/orgs/Homebrew/discussions/3595 + # If pg_config is available and agrees, let's use its indications. + if not libname: + try: + import subprocess as sp + + libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode() + libname = os.path.join(libdir, "libpq.dylib") + if not os.path.exists(libname): + libname = None + except Exception as ex: + logger.debug("couldn't use pg_config to find libpq: %s", ex) + + else: + libname = ctypes.util.find_library("pq") + + return libname + + +def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str: + """ + Return an error message from a `PGconn` or `PGresult`. + + The return value is a `!str` (unlike pq data which is usually `!bytes`): + use the connection encoding if available, otherwise the `!encoding` + parameter as a fallback for decoding. Don't raise exceptions on decoding + errors. + + """ + bmsg: bytes + + if hasattr(obj, "error_field"): + # obj is a PGresult + obj = cast(PGresult, obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + elif hasattr(obj, "error_message"): + # obj is a PGconn + if obj.status == OK: + encoding = pgconn_encoding(obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + else: + raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}") + + if bmsg: + msg = bmsg.decode(encoding, "replace") + else: + msg = "no details available" + + return msg + + +def connection_summary(pgconn: PGconn) -> str: + """ + Return summary information on a connection. + + Useful for __repr__ + """ + parts = [] + if pgconn.status == OK: + # Put together the [STATUS] + status = TransactionStatus(pgconn.transaction_status).name + if pgconn.pipeline_status: + status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}" + + # Put together the (CONNECTION) + if not pgconn.host.startswith(b"/"): + parts.append(("host", pgconn.host.decode())) + if pgconn.port != b"5432": + parts.append(("port", pgconn.port.decode())) + if pgconn.user != pgconn.db: + parts.append(("user", pgconn.user.decode())) + parts.append(("database", pgconn.db.decode())) + + else: + status = ConnStatus(pgconn.status).name + + sparts = " ".join("%s=%s" % part for part in parts) + if sparts: + sparts = f" ({sparts})" + return f"[{status}]{sparts}" diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py new file mode 100644 index 0000000..8b87c19 --- /dev/null +++ b/psycopg/psycopg/pq/pq_ctypes.py @@ -0,0 +1,1086 @@ +""" +libpq Python wrapper using ctypes bindings. + +Clients shouldn't use this module directly, unless for testing: they should use +the `pq` module instead, which is in charge of choosing the best +implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import logging +from os import getpid +from weakref import ref + +from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref +from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import cast as t_cast, TYPE_CHECKING + +from .. import errors as e +from . import _pq_ctypes as impl +from .misc import PGnotify, ConninfoOption, PGresAttDesc +from .misc import error_message, connection_summary +from ._enums import Format, ExecStatus, Trace + +# Imported locally to call them from __del__ methods +from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus + +if TYPE_CHECKING: + from . import abc + +__impl__ = "python" + +logger = logging.getLogger("psycopg") + + +def version() -> int: + """Return the version number of the libpq currently loaded. + + The number is in the same format of `~psycopg.ConnectionInfo.server_version`. + + Certain features might not be available if the libpq library used is too old. + """ + return impl.PQlibVersion() + + +@impl.PQnoticeReceiver # type: ignore +def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None: + pgconn = cast(arg, POINTER(py_object)).contents.value() + if not (pgconn and pgconn.notice_handler): + return + + res = PGresult(result_ptr) + try: + pgconn.notice_handler(res) + except Exception as exc: + logger.exception("error in notice receiver: %s", exc) + finally: + res._pgresult_ptr = None # avoid destroying the pgresult_ptr + + +class PGconn: + """ + Python representation of a libpq connection. + """ + + __slots__ = ( + "_pgconn_ptr", + "notice_handler", + "notify_handler", + "_self_ptr", + "_procpid", + "__weakref__", + ) + + def __init__(self, pgconn_ptr: impl.PGconn_struct): + self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr + self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None + self.notify_handler: Optional[Callable[[PGnotify], None]] = None + + # Keep alive for the lifetime of PGconn + self._self_ptr = py_object(ref(self)) + impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr)) + + self._procpid = getpid() + + def __del__(self) -> None: + # Close the connection only if it was created in this process, + # not if this object is being GC'd after fork. + if getpid() == self._procpid: + self.finish() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self) + return f"<{cls} {info} at 0x{id(self):x}>" + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectdb(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectStart(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + def connect_poll(self) -> int: + return self._call_int(impl.PQconnectPoll) + + def finish(self) -> None: + self._pgconn_ptr, p = None, self._pgconn_ptr + if p: + PQfinish(p) + + @property + def pgconn_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGconn` structure, as integer. + + `!None` if the connection is closed. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. + """ + if self._pgconn_ptr is None: + return None + + return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined] + + @property + def info(self) -> List["ConninfoOption"]: + self._ensure_pgconn() + opts = impl.PQconninfo(self._pgconn_ptr) + if not opts: + raise MemoryError("couldn't allocate connection info") + try: + return Conninfo._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + def reset(self) -> None: + self._ensure_pgconn() + impl.PQreset(self._pgconn_ptr) + + def reset_start(self) -> None: + if not impl.PQresetStart(self._pgconn_ptr): + raise e.OperationalError("couldn't reset connection") + + def reset_poll(self) -> int: + return self._call_int(impl.PQresetPoll) + + @classmethod + def ping(self, conninfo: bytes) -> int: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + return impl.PQping(conninfo) + + @property + def db(self) -> bytes: + return self._call_bytes(impl.PQdb) + + @property + def user(self) -> bytes: + return self._call_bytes(impl.PQuser) + + @property + def password(self) -> bytes: + return self._call_bytes(impl.PQpass) + + @property + def host(self) -> bytes: + return self._call_bytes(impl.PQhost) + + @property + def hostaddr(self) -> bytes: + return self._call_bytes(impl.PQhostaddr) + + @property + def port(self) -> bytes: + return self._call_bytes(impl.PQport) + + @property + def tty(self) -> bytes: + return self._call_bytes(impl.PQtty) + + @property + def options(self) -> bytes: + return self._call_bytes(impl.PQoptions) + + @property + def status(self) -> int: + return PQstatus(self._pgconn_ptr) + + @property + def transaction_status(self) -> int: + return impl.PQtransactionStatus(self._pgconn_ptr) + + def parameter_status(self, name: bytes) -> Optional[bytes]: + self._ensure_pgconn() + return impl.PQparameterStatus(self._pgconn_ptr, name) + + @property + def error_message(self) -> bytes: + return impl.PQerrorMessage(self._pgconn_ptr) + + @property + def protocol_version(self) -> int: + return self._call_int(impl.PQprotocolVersion) + + @property + def server_version(self) -> int: + return self._call_int(impl.PQserverVersion) + + @property + def socket(self) -> int: + rv = self._call_int(impl.PQsocket) + if rv == -1: + raise e.OperationalError("the connection is lost") + return rv + + @property + def backend_pid(self) -> int: + return self._call_int(impl.PQbackendPID) + + @property + def needs_password(self) -> bool: + """True if the connection authentication method required a password, + but none was available. + + See :pq:`PQconnectionNeedsPassword` for details. + """ + return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr)) + + @property + def used_password(self) -> bool: + """True if the connection authentication method used a password. + + See :pq:`PQconnectionUsedPassword` for details. + """ + return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr)) + + @property + def ssl_in_use(self) -> bool: + return self._call_bool(impl.PQsslInUse) + + def exec_(self, command: bytes) -> "PGresult": + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + rv = impl.PQexec(self._pgconn_ptr, command) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query(self, command: bytes) -> None: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + if not impl.PQsendQuery(self._pgconn_ptr, command): + raise e.OperationalError(f"sending query failed: {error_message(self)}") + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + rv = impl.PQexecParams(*args) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + if not impl.PQsendQueryParams(*args): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + atypes: Optional[Array[impl.Oid]] + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + # repurpose this function with a cheeky replacement of query with name, + # drop the param_types from the result + args = self._query_params_args( + name, param_values, None, param_formats, result_format + ) + args = args[:3] + args[4:] + + self._ensure_pgconn() + if not impl.PQsendQueryPrepared(*args): + raise e.OperationalError( + f"sending prepared query failed: {error_message(self)}" + ) + + def _query_params_args( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> Any: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + atypes: Optional[Array[impl.Oid]] + if not param_types: + atypes = None + else: + if len(param_types) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_types)) + ) + atypes = (impl.Oid * nparams)(*param_types) + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_formats" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + return ( + self._pgconn_ptr, + command, + nparams, + atypes, + aparams, + alenghts, + aformats, + result_format, + ) + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + if not isinstance(command, bytes): + raise TypeError(f"'command' must be bytes, got {type(command)} instead") + + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence["abc.Buffer"]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + self._ensure_pgconn() + rv = impl.PQexecPrepared( + self._pgconn_ptr, + name, + nparams, + aparams, + alenghts, + aformats, + result_format, + ) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def describe_prepared(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePrepared(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_describe_prepared(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePrepared(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def describe_portal(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePortal(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_describe_portal(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePortal(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe portal failed: {error_message(self)}" + ) + + def get_result(self) -> Optional["PGresult"]: + rv = impl.PQgetResult(self._pgconn_ptr) + return PGresult(rv) if rv else None + + def consume_input(self) -> None: + if 1 != impl.PQconsumeInput(self._pgconn_ptr): + raise e.OperationalError(f"consuming input failed: {error_message(self)}") + + def is_busy(self) -> int: + return impl.PQisBusy(self._pgconn_ptr) + + @property + def nonblocking(self) -> int: + return impl.PQisnonblocking(self._pgconn_ptr) + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg): + raise e.OperationalError( + f"setting nonblocking failed: {error_message(self)}" + ) + + def flush(self) -> int: + # PQflush segfaults if it receives a NULL connection + if not self._pgconn_ptr: + raise e.OperationalError("flushing failed: the connection is closed") + rv: int = impl.PQflush(self._pgconn_ptr) + if rv < 0: + raise e.OperationalError(f"flushing failed: {error_message(self)}") + return rv + + def set_single_row_mode(self) -> None: + if not impl.PQsetSingleRowMode(self._pgconn_ptr): + raise e.OperationalError("setting single row mode failed") + + def get_cancel(self) -> "PGcancel": + """ + Create an object with the information needed to cancel a command. + + See :pq:`PQgetCancel` for details. + """ + rv = impl.PQgetCancel(self._pgconn_ptr) + if not rv: + raise e.OperationalError("couldn't create cancel object") + return PGcancel(rv) + + def notifies(self) -> Optional[PGnotify]: + ptr = impl.PQnotifies(self._pgconn_ptr) + if ptr: + c = ptr.contents + return PGnotify(c.relname, c.be_pid, c.extra) + impl.PQfreemem(ptr) + else: + return None + + def put_copy_data(self, buffer: "abc.Buffer") -> int: + if not isinstance(buffer, bytes): + buffer = bytes(buffer) + rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer)) + if rv < 0: + raise e.OperationalError(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + rv = impl.PQputCopyEnd(self._pgconn_ptr, error) + if rv < 0: + raise e.OperationalError(f"sending copy end failed: {error_message(self)}") + return rv + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + buffer_ptr = c_char_p() + nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_) + if nbytes == -2: + raise e.OperationalError( + f"receiving copy data failed: {error_message(self)}" + ) + if buffer_ptr: + # TODO: do it without copy + data = string_at(buffer_ptr, nbytes) + impl.PQfreemem(buffer_ptr) + return nbytes, memoryview(data) + else: + return nbytes, memoryview(b"") + + def trace(self, fileno: int) -> None: + """ + Enable tracing of the client/server communication to a file stream. + + See :pq:`PQtrace` for details. + """ + if sys.platform != "linux": + raise e.NotSupportedError("currently only supported on Linux") + stream = impl.fdopen(fileno, b"w") + impl.PQtrace(self._pgconn_ptr, stream) + + def set_trace_flags(self, flags: Trace) -> None: + """ + Configure tracing behavior of client/server communication. + + :param flags: operating mode of tracing. + + See :pq:`PQsetTraceFlags` for details. + """ + impl.PQsetTraceFlags(self._pgconn_ptr, flags) + + def untrace(self) -> None: + """ + Disable tracing, previously enabled through `trace()`. + + See :pq:`PQuntrace` for details. + """ + impl.PQuntrace(self._pgconn_ptr) + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + """ + Return the encrypted form of a PostgreSQL password. + + See :pq:`PQencryptPasswordConn` for details. + """ + out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm) + if not out: + raise e.OperationalError( + f"password encryption failed: {error_message(self)}" + ) + + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def make_empty_result(self, exec_status: int) -> "PGresult": + rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status) + if not rv: + raise MemoryError("couldn't allocate empty PGresult") + return PGresult(rv) + + @property + def pipeline_status(self) -> int: + if version() < 140000: + return 0 + return impl.PQpipelineStatus(self._pgconn_ptr) + + def enter_pipeline_mode(self) -> None: + """Enter pipeline mode. + + :raises ~e.OperationalError: in case of failure to enter the pipeline + mode. + """ + if impl.PQenterPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError("failed to enter pipeline mode") + + def exit_pipeline_mode(self) -> None: + """Exit pipeline mode. + + :raises ~e.OperationalError: in case of failure to exit the pipeline + mode. + """ + if impl.PQexitPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError(error_message(self)) + + def pipeline_sync(self) -> None: + """Mark a synchronization point in a pipeline. + + :raises ~e.OperationalError: if the connection is not in pipeline mode + or if sync failed. + """ + rv = impl.PQpipelineSync(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError("connection not in pipeline mode") + if rv != 1: + raise e.OperationalError("failed to sync pipeline") + + def send_flush_request(self) -> None: + """Sends a request for the server to flush its output buffer. + + :raises ~e.OperationalError: if the flush request failed. + """ + if impl.PQsendFlushRequest(self._pgconn_ptr) == 0: + raise e.OperationalError(f"flush request failed: {error_message(self)}") + + def _call_bytes( + self, func: Callable[[impl.PGconn_struct], Optional[bytes]] + ) -> bytes: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + rv = func(self._pgconn_ptr) + assert rv is not None + return rv + + def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int: + """ + Call one of the pgconn libpq functions returning an int. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return func(self._pgconn_ptr) + + def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool: + """ + Call one of the pgconn libpq functions returning a logical value. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return bool(func(self._pgconn_ptr)) + + def _ensure_pgconn(self) -> None: + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + + +class PGresult: + """ + Python representation of a libpq result. + """ + + __slots__ = ("_pgresult_ptr",) + + def __init__(self, pgresult_ptr: impl.PGresult_struct): + self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr + + def __del__(self) -> None: + self.clear() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + status = ExecStatus(self.status) + return f"<{cls} [{status.name}] at 0x{id(self):x}>" + + def clear(self) -> None: + self._pgresult_ptr, p = None, self._pgresult_ptr + if p: + PQclear(p) + + @property + def pgresult_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGresult` structure, as integer. + + `!None` if the result was cleared. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. + """ + if self._pgresult_ptr is None: + return None + + return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined] + + @property + def status(self) -> int: + return impl.PQresultStatus(self._pgresult_ptr) + + @property + def error_message(self) -> bytes: + return impl.PQresultErrorMessage(self._pgresult_ptr) + + def error_field(self, fieldcode: int) -> Optional[bytes]: + return impl.PQresultErrorField(self._pgresult_ptr, fieldcode) + + @property + def ntuples(self) -> int: + return impl.PQntuples(self._pgresult_ptr) + + @property + def nfields(self) -> int: + return impl.PQnfields(self._pgresult_ptr) + + def fname(self, column_number: int) -> Optional[bytes]: + return impl.PQfname(self._pgresult_ptr, column_number) + + def ftable(self, column_number: int) -> int: + return impl.PQftable(self._pgresult_ptr, column_number) + + def ftablecol(self, column_number: int) -> int: + return impl.PQftablecol(self._pgresult_ptr, column_number) + + def fformat(self, column_number: int) -> int: + return impl.PQfformat(self._pgresult_ptr, column_number) + + def ftype(self, column_number: int) -> int: + return impl.PQftype(self._pgresult_ptr, column_number) + + def fmod(self, column_number: int) -> int: + return impl.PQfmod(self._pgresult_ptr, column_number) + + def fsize(self, column_number: int) -> int: + return impl.PQfsize(self._pgresult_ptr, column_number) + + @property + def binary_tuples(self) -> int: + return impl.PQbinaryTuples(self._pgresult_ptr) + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number) + if length: + v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number) + return string_at(v, length) + else: + if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number): + return None + else: + return b"" + + @property + def nparams(self) -> int: + return impl.PQnparams(self._pgresult_ptr) + + def param_type(self, param_number: int) -> int: + return impl.PQparamtype(self._pgresult_ptr, param_number) + + @property + def command_status(self) -> Optional[bytes]: + return impl.PQcmdStatus(self._pgresult_ptr) + + @property + def command_tuples(self) -> Optional[int]: + rv = impl.PQcmdTuples(self._pgresult_ptr) + return int(rv) if rv else None + + @property + def oid_value(self) -> int: + return impl.PQoidValue(self._pgresult_ptr) + + def set_attributes(self, descriptions: List[PGresAttDesc]) -> None: + structs = [ + impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore + ] + array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore + rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) + if rv == 0: + raise e.OperationalError("PQsetResultAttrs failed") + + +class PGcancel: + """ + Token to cancel the current operation on a connection. + + Created by `PGconn.get_cancel()`. + """ + + __slots__ = ("pgcancel_ptr",) + + def __init__(self, pgcancel_ptr: impl.PGcancel_struct): + self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr + + def __del__(self) -> None: + self.free() + + def free(self) -> None: + """ + Free the data structure created by :pq:`PQgetCancel()`. + + Automatically invoked by `!__del__()`. + + See :pq:`PQfreeCancel()` for details. + """ + self.pgcancel_ptr, p = None, self.pgcancel_ptr + if p: + PQfreeCancel(p) + + def cancel(self) -> None: + """Requests that the server abandon processing of the current command. + + See :pq:`PQcancel()` for details. + """ + buf = create_string_buffer(256) + res = impl.PQcancel( + self.pgcancel_ptr, + byref(buf), # type: ignore[arg-type] + len(buf), + ) + if not res: + raise e.OperationalError( + f"cancel failed: {buf.value.decode('utf8', 'ignore')}" + ) + + +class Conninfo: + """ + Utility object to manipulate connection strings. + """ + + @classmethod + def get_defaults(cls) -> List[ConninfoOption]: + opts = impl.PQconndefaults() + if not opts: + raise MemoryError("couldn't allocate connection defaults") + try: + return cls._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + @classmethod + def parse(cls, conninfo: bytes) -> List[ConninfoOption]: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + errmsg = c_char_p() + rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type] + if not rv: + if not errmsg: + raise MemoryError("couldn't allocate on conninfo parse") + else: + exc = e.OperationalError( + (errmsg.value or b"").decode("utf8", "replace") + ) + impl.PQfreemem(errmsg) + raise exc + + try: + return cls._options_from_array(rv) + finally: + impl.PQconninfoFree(rv) + + @classmethod + def _options_from_array( + cls, opts: Sequence[impl.PQconninfoOption_struct] + ) -> List[ConninfoOption]: + rv = [] + skws = "keyword envvar compiled val label dispchar".split() + for opt in opts: + if not opt.keyword: + break + d = {kw: getattr(opt, kw) for kw in skws} + d["dispsize"] = opt.dispsize + rv.append(ConninfoOption(**d)) + + return rv + + +class Escaping: + """ + Utility object to escape strings for SQL interpolation. + """ + + def __init__(self, conn: Optional[PGconn] = None): + self.conn = conn + + def escape_literal(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_literal failed: no connection provided") + + self.conn._ensure_pgconn() + # TODO: might be done without copy (however C does that) + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_literal failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_identifier(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_identifier failed: no connection provided") + + self.conn._ensure_pgconn() + + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_string(self, data: "abc.Buffer") -> bytes: + if not isinstance(data, bytes): + data = bytes(data) + + if self.conn: + self.conn._ensure_pgconn() + error = c_int() + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeStringConn( + self.conn._pgconn_ptr, + byref(out), # type: ignore[arg-type] + data, + len(data), + byref(error), # type: ignore[arg-type] + ) + + if error: + raise e.OperationalError( + f"escape_string failed: {error_message(self.conn)} bytes" + ) + + else: + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeString( + byref(out), # type: ignore[arg-type] + data, + len(data), + ) + + return out.value + + def escape_bytea(self, data: "abc.Buffer") -> bytes: + len_out = c_size_t() + # TODO: might be able to do without a copy but it's a mess. + # the C library does it better anyway, so maybe not worth optimising + # https://mail.python.org/pipermail/python-dev/2012-September/121780.html + if not isinstance(data, bytes): + data = bytes(data) + if self.conn: + self.conn._ensure_pgconn() + out = impl.PQescapeByteaConn( + self.conn._pgconn_ptr, + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + else: + out = impl.PQescapeBytea( + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value - 1) # out includes final 0 + impl.PQfreemem(out) + return rv + + def unescape_bytea(self, data: "abc.Buffer") -> bytes: + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. + if self.conn: + self.conn._ensure_pgconn() + + len_out = c_size_t() + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQunescapeBytea( + data, + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value) + impl.PQfreemem(out) + return rv + + +# importing the ssl module sets up Python's libcrypto callbacks +import ssl # noqa + +# disable libcrypto setup in libpq, so it won't stomp on the callbacks +# that have already been set up +impl.PQinitOpenSSL(1, 0) + +__build_version__ = version() diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg/psycopg/py.typed diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py new file mode 100644 index 0000000..cb28b57 --- /dev/null +++ b/psycopg/psycopg/rows.py @@ -0,0 +1,256 @@ +""" +psycopg row factories +""" + +# Copyright (C) 2021 The Psycopg Team + +import functools +from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn +from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar +from collections import namedtuple +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from ._compat import Protocol +from ._encodings import _as_python_identifier + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from psycopg.pq.abc import PGresult + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE + +T = TypeVar("T", covariant=True) + +# Row factories + +Row = TypeVar("Row", covariant=True) + + +class RowMaker(Protocol[Row]): + """ + Callable protocol taking a sequence of value and returning an object. + + The sequence of value is what is returned from a database query, already + adapted to the right Python types. The return value is the object that your + program would like to receive: by default (`tuple_row()`) it is a simple + tuple, but it may be any type of object. + + Typically, `!RowMaker` functions are returned by `RowFactory`. + """ + + def __call__(self, __values: Sequence[Any]) -> Row: + ... + + +class RowFactory(Protocol[Row]): + """ + Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`. + + A `!RowFactory` is typically called when a `!Cursor` receives a result. + This way it can inspect the cursor state (for instance the + `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create + a complete object. + + For instance the `dict_row()` `!RowFactory` uses the names of the column to + define the dictionary key and returns a `!RowMaker` function which would + use the values to create a dictionary for each record. + """ + + def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: + ... + + +class AsyncRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking an async cursor as argument. + """ + + def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: + ... + + +class BaseRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking either type of cursor as argument. + """ + + def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: + ... + + +TupleRow: TypeAlias = Tuple[Any, ...] +""" +An alias for the type returned by `tuple_row()` (i.e. a tuple of any content). +""" + + +DictRow: TypeAlias = Dict[str, Any] +""" +An alias for the type returned by `dict_row()` + +A `!DictRow` is a dictionary with keys as string and any value returned by the +database. +""" + + +def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]": + r"""Row factory to represent rows as simple tuples. + + This is the default factory, used when `~psycopg.Connection.connect()` or + `~psycopg.Connection.cursor()` are called without a `!row_factory` + parameter. + + """ + # Implementation detail: make sure this is the tuple type itself, not an + # equivalent function, because the C code fast-paths on it. + return tuple + + +def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]": + """Row factory to represent rows as dictionaries. + + The dictionary keys are taken from the column names of the returned columns. + """ + names = _get_names(cursor) + if names is None: + return no_result + + def dict_row_(values: Sequence[Any]) -> Dict[str, Any]: + # https://github.com/python/mypy/issues/2608 + return dict(zip(names, values)) # type: ignore[arg-type] + + return dict_row_ + + +def namedtuple_row( + cursor: "BaseCursor[Any, Any]", +) -> "RowMaker[NamedTuple]": + """Row factory to represent rows as `~collections.namedtuple`. + + The field names are taken from the column names of the returned columns, + with some mangling to deal with invalid names. + """ + res = cursor.pgresult + if not res: + return no_result + + nfields = _get_nfields(res) + if nfields is None: + return no_result + + nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields))) + return nt._make + + +@functools.lru_cache(512) +def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]: + snames = tuple(_as_python_identifier(n.decode(enc)) for n in names) + return namedtuple("Row", snames) # type: ignore[return-value] + + +def class_row(cls: Type[T]) -> BaseRowFactory[T]: + r"""Generate a row factory to represent rows as instances of the class `!cls`. + + The class must support every output column name as a keyword parameter. + + :param cls: The class to return for each row. It must support the fields + returned by the query as keyword arguments. + :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]] + """ + + def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def class_row__(values: Sequence[Any]) -> T: + return cls(**dict(zip(names, values))) # type: ignore[arg-type] + + return class_row__ + + return class_row_ + + +def args_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with positional parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as positional arguments. + """ + + def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]": + def args_row__(values: Sequence[Any]) -> T: + return func(*values) + + return args_row__ + + return args_row_ + + +def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with keyword parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as keyword arguments. + """ + + def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def kwargs_row__(values: Sequence[Any]) -> T: + return func(**dict(zip(names, values))) # type: ignore[arg-type] + + return kwargs_row__ + + return kwargs_row_ + + +def no_result(values: Sequence[Any]) -> NoReturn: + """A `RowMaker` that always fail. + + It can be used as return value for a `RowFactory` called with no result. + Note that the `!RowFactory` *will* be called with no result, but the + resulting `!RowMaker` never should. + """ + raise e.InterfaceError("the cursor doesn't have a result") + + +def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]: + res = cursor.pgresult + if not res: + return None + + nfields = _get_nfields(res) + if nfields is None: + return None + + enc = cursor._encoding + return [ + res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr] + ] + + +def _get_nfields(res: "PGresult") -> Optional[int]: + """ + Return the number of columns in a result, if it returns tuples else None + + Take into account the special case of results with zero columns. + """ + nfields = res.nfields + + if ( + res.status == TUPLES_OK + or res.status == SINGLE_TUPLE + # "describe" in named cursors + or (res.status == COMMAND_OK and nfields) + ): + return nfields + else: + return None diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py new file mode 100644 index 0000000..b890d77 --- /dev/null +++ b/psycopg/psycopg/server_cursor.py @@ -0,0 +1,479 @@ +""" +psycopg server-side cursor objects. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, AsyncIterator, List, Iterable, Iterator +from typing import Optional, TypeVar, TYPE_CHECKING, overload +from warnings import warn + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .rows import Row, RowFactory, AsyncRowFactory +from .cursor import BaseCursor, Cursor +from .generators import execute +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + +DEFAULT_ITERSIZE = 100 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + + +class ServerCursorMixin(BaseCursor[ConnectionType, Row]): + """Mixin to add ServerCursor behaviour and implementation a BaseCursor.""" + + __slots__ = "_name _scrollable _withhold _described itersize _format".split() + + def __init__( + self, + name: str, + scrollable: Optional[bool], + withhold: bool, + ): + self._name = name + self._scrollable = scrollable + self._withhold = withhold + self._described = False + self.itersize: int = DEFAULT_ITERSIZE + self._format = TEXT + + def __repr__(self) -> str: + # Insert the name as the second word + parts = super().__repr__().split(None, 1) + parts.insert(1, f"{self._name!r}") + return " ".join(parts) + + @property + def name(self) -> str: + """The name of the cursor.""" + return self._name + + @property + def scrollable(self) -> Optional[bool]: + """ + Whether the cursor is scrollable or not. + + If `!None` leave the choice to the server. Use `!True` if you want to + use `scroll()` on the cursor. + """ + return self._scrollable + + @property + def withhold(self) -> bool: + """ + If the cursor can be used after the creating transaction has committed. + """ + return self._withhold + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + res = self.pgresult + # command_status is empty if the result comes from + # describe_portal, which means that we have just executed the DECLARE, + # so we can assume we are at the first row. + tuples = res and (res.status == TUPLES_OK or res.command_status == b"") + return self._pos if tuples else None + + def _declare_gen( + self, + query: Query, + params: Optional[Params] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `ServerCursor.execute()`.""" + + query = self._make_declare_statement(query) + + # If the cursor is being reused, the previous one must be closed. + if self._described: + yield from self._close_gen() + self._described = False + + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, force_extended=True) + results = yield from execute(self._conn.pgconn) + if results[-1].status != COMMAND_OK: + self._raise_for_result(results[-1]) + + # Set the format, which will be used by describe and fetch operations + if binary is None: + self._format = self.format + else: + self._format = BINARY if binary else TEXT + + # The above result only returned COMMAND_OK. Get the cursor shape + yield from self._describe_gen() + + def _describe_gen(self) -> PQGen[None]: + self._pgconn.send_describe_portal(self._name.encode(self._encoding)) + results = yield from execute(self._pgconn) + self._check_results(results) + self._results = results + self._select_current_result(0, format=self._format) + self._described = True + + def _close_gen(self) -> PQGen[None]: + ts = self._conn.pgconn.transaction_status + + # if the connection is not in a sane state, don't even try + if ts != IDLE and ts != INTRANS: + return + + # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already. + if not self._withhold and ts == IDLE: + return + + # if we didn't declare the cursor ourselves we still have to close it + # but we must make sure it exists. + if not self._described: + query = sql.SQL( + "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}" + ).format(sql.Literal(self._name)) + res = yield from self._conn._exec_command(query) + # pipeline mode otherwise, unsupported here. + assert res is not None + if res.ntuples == 0: + return + + query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name)) + yield from self._conn._exec_command(query) + + def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]: + if self.closed: + raise e.InterfaceError("the cursor is closed") + # If we are stealing the cursor, make sure we know its shape + if not self._described: + yield from self._start_query() + yield from self._describe_gen() + + query = sql.SQL("FETCH FORWARD {} FROM {}").format( + sql.SQL("ALL") if num is None else sql.Literal(num), + sql.Identifier(self._name), + ) + res = yield from self._conn._exec_command(query, result_format=self._format) + # pipeline mode otherwise, unsupported here. + assert res is not None + + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=False) + return self._tx.load_rows(0, res.ntuples, self._make_row) + + def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: + if mode not in ("relative", "absolute"): + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + query = sql.SQL("MOVE{} {} FROM {}").format( + sql.SQL(" ABSOLUTE" if mode == "absolute" else ""), + sql.Literal(value), + sql.Identifier(self._name), + ) + yield from self._conn._exec_command(query) + + def _make_declare_statement(self, query: Query) -> sql.Composed: + + if isinstance(query, bytes): + query = query.decode(self._encoding) + if not isinstance(query, sql.Composable): + query = sql.SQL(query) + + parts = [ + sql.SQL("DECLARE"), + sql.Identifier(self._name), + ] + if self._scrollable is not None: + parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL")) + parts.append(sql.SQL("CURSOR")) + if self._withhold: + parts.append(sql.SQL("WITH HOLD")) + parts.append(sql.SQL("FOR")) + parts.append(query) + + return sql.SQL(" ").join(parts) + + +class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="ServerCursor[Any]") + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Any]", + name: str, + *, + row_factory: RowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "Connection[Any]", + name: str, + *, + row_factory: Optional[RowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + Cursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." + " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + self._conn.wait(self._close_gen()) + super().close() + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + """ + Open a cursor to execute a query to the database. + """ + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + with self._conn.lock: + self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + """Method not implemented for server-side cursors.""" + raise e.NotSupportedError("executemany not supported on server-side cursors") + + def fetchone(self) -> Optional[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + def fetchall(self) -> List[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + def __iter__(self) -> Iterator[Row]: + while True: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + def scroll(self, value: int, mode: str = "relative") -> None: + with self._conn.lock: + self._conn.wait(self._scroll_gen(value, mode)) + # Postgres doesn't have a reliable way to report a cursor out of bound + if mode == "relative": + self._pos += value + else: + self._pos = value + + +class AsyncServerCursor( + ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]") + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: AsyncRowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + AsyncCursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." + " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + async def close(self) -> None: + async with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + await self._conn.wait(self._close_gen()) + await super().close() + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + async with self._conn.lock: + await self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + raise e.NotSupportedError("executemany not supported on server-side cursors") + + async def fetchone(self) -> Optional[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + async def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + async def fetchall(self) -> List[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + async def __aiter__(self) -> AsyncIterator[Row]: + while True: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + async def scroll(self, value: int, mode: str = "relative") -> None: + async with self._conn.lock: + await self._conn.wait(self._scroll_gen(value, mode)) diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py new file mode 100644 index 0000000..099a01c --- /dev/null +++ b/psycopg/psycopg/sql.py @@ -0,0 +1,467 @@ +""" +SQL composition utility module +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +import string +from abc import ABC, abstractmethod +from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union + +from .pq import Escaping +from .abc import AdaptContext +from .adapt import Transformer, PyFormat +from ._compat import LiteralString +from ._encodings import conn_encoding + + +def quote(obj: Any, context: Optional[AdaptContext] = None) -> str: + """ + Adapt a Python object to a quoted SQL string. + + Use this function only if you absolutely want to convert a Python string to + an SQL quoted literal to use e.g. to generate batch SQL and you won't have + a connection available when you will need to use it. + + This function is relatively inefficient, because it doesn't cache the + adaptation rules. If you pass a `!context` you can adapt the adaptation + rules used, otherwise only global rules are used. + + """ + return Literal(obj).as_string(context) + + +class Composable(ABC): + """ + Abstract base class for objects that can be used to compose an SQL string. + + `!Composable` objects can be passed directly to + `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`, + `~psycopg.Cursor.copy()` in place of the query string. + + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. + """ + + def __init__(self, obj: Any): + self._obj = obj + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._obj!r})" + + @abstractmethod + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + """ + Return the value of the object as bytes. + + :param context: the context to evaluate the object into. + :type context: `connection` or `cursor` + + The method is automatically invoked by `~psycopg.Cursor.execute()`, + `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a + `!Composable` is passed instead of the query string. + + """ + raise NotImplementedError + + def as_string(self, context: Optional[AdaptContext]) -> str: + """ + Return the value of the object as string. + + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` + + """ + conn = context.connection if context else None + enc = conn_encoding(conn) + b = self.as_bytes(context) + if isinstance(b, bytes): + return b.decode(enc) + else: + # buffer object + return codecs.lookup(enc).decode(b)[0] + + def __add__(self, other: "Composable") -> "Composed": + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composable): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + def __mul__(self, n: int) -> "Composed": + return Composed([self] * n) + + def __eq__(self, other: Any) -> bool: + return type(self) is type(other) and self._obj == other._obj + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Composed(Composable): + """ + A `Composable` object made of a sequence of `!Composable`. + + The object is usually created using `!Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of objects as arguments: if they are not `!Composable` they will + be wrapped in a `Literal`. + + Example:: + + >>> comp = sql.Composed( + ... [sql.SQL("INSERT INTO "), sql.Identifier("table")]) + >>> print(comp.as_string(conn)) + INSERT INTO "table" + + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). + """ + + _obj: List[Composable] + + def __init__(self, seq: Sequence[Any]): + seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq] + super().__init__(seq) + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + return b"".join(obj.as_bytes(context) for obj in self._obj) + + def __iter__(self) -> Iterator[Composable]: + return iter(self._obj) + + def __add__(self, other: Composable) -> "Composed": + if isinstance(other, Composed): + return Composed(self._obj + other._obj) + if isinstance(other, Composable): + return Composed(self._obj + [other]) + else: + return NotImplemented + + def join(self, joiner: Union["SQL", LiteralString]) -> "Composed": + """ + Return a new `!Composed` interposing the `!joiner` with the `!Composed` items. + + The `!joiner` must be a `SQL` or a string which will be interpreted as + an `SQL`. + + Example:: + + >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed + >>> print(fields.join(', ').as_string(conn)) + "foo", "bar" + + """ + if isinstance(joiner, str): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be strings or SQL," + f" got {joiner!r} instead" + ) + + return joiner.join(self._obj) + + +class SQL(Composable): + """ + A `Composable` representing a snippet of SQL statement. + + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). + + The `!obj` string doesn't undergo any form of escaping, so it is not + suitable to represent variable identifiers or values: you should only use + it to pass constant strings representing templates or snippets of SQL + statements; use other objects such as `Identifier` or `Literal` to + represent variable parts. + + Example:: + + >>> query = sql.SQL("SELECT {0} FROM {1}").format( + ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), + ... sql.Identifier('table')) + >>> print(query.as_string(conn)) + SELECT "foo", "bar" FROM "table" + """ + + _obj: LiteralString + _formatter = string.Formatter() + + def __init__(self, obj: LiteralString): + super().__init__(obj) + if not isinstance(obj, str): + raise TypeError(f"SQL values must be strings, got {obj!r} instead") + + def as_string(self, context: Optional[AdaptContext]) -> str: + return self._obj + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + enc = "utf-8" + if context: + enc = conn_encoding(context.connection) + return self._obj.encode(enc) + + def format(self, *args: Any, **kwargs: Any) -> Composed: + """ + Merge `Composable` objects into a template. + + :param args: parameters to replace to numbered (``{0}``, ``{1}``) or + auto-numbered (``{}``) placeholders + :param kwargs: parameters to replace to named (``{name}``) placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``), numbered (``{0}``, + ``{1}``...), and named placeholders (``{name}``), with positional + arguments replacing the numbered placeholders and keywords replacing + the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``) + are not supported. + + If a `!Composable` objects is passed to the template it will be merged + according to its `as_string()` method. If any other Python object is + passed, it will be wrapped in a `Literal` object and so escaped + according to SQL rules. + + Example:: + + >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + SELECT * FROM "people" WHERE "id" = %s + + >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}") + ... .format(tbl=sql.Identifier('people'), name="O'Rourke")) + ... .as_string(conn)) + SELECT * FROM "people" WHERE name = 'O''Rourke' + + """ + rv: List[Composable] = [] + autonum: Optional[int] = 0 + # TODO: this is probably not the right way to whitelist pre + # pyre complains. Will wait for mypy to complain too to fix. + pre: LiteralString + for pre, name, spec, conv in self._formatter.parse(self._obj): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual" + ) + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic" + ) + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + + def join(self, seq: Iterable[Composable]) -> Composed: + """ + Join a sequence of `Composable`. + + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's string to separate the elements in `!seq`. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. + + Example:: + + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) + >>> print(snip.as_string(conn)) + "foo", "bar", "baz" + """ + rv = [] + it = iter(seq) + try: + rv.append(next(it)) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composable): + """ + A `Composable` representing an SQL identifier or a dot-separated sequence. + + Identifiers usually represent names of database objects, such as tables or + fields. PostgreSQL identifiers follow `different rules`__ than SQL string + literals for escaping (e.g. they use double quotes instead of single). + + .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \ + SQL-SYNTAX-IDENTIFIERS + + Example:: + + >>> t1 = sql.Identifier("foo") + >>> t2 = sql.Identifier("ba'r") + >>> t3 = sql.Identifier('ba"z') + >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn)) + "foo", "ba'r", "ba""z" + + Multiple strings can be passed to the object to represent a qualified name, + i.e. a dot-separated sequence of identifiers. + + Example:: + + >>> query = sql.SQL("SELECT {} FROM {}").format( + ... sql.Identifier("table", "field"), + ... sql.Identifier("schema", "table")) + >>> print(query.as_string(conn)) + SELECT "table"."field" FROM "schema"."table" + + """ + + _obj: Sequence[str] + + def __init__(self, *strings: str): + # init super() now to make the __repr__ not explode in case of error + super().__init__(strings) + + if not strings: + raise TypeError("Identifier cannot be empty") + + for s in strings: + if not isinstance(s, str): + raise TypeError( + f"SQL identifier parts must be strings, got {s!r} instead" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + if not conn: + raise ValueError("a connection is necessary for Identifier") + esc = Escaping(conn.pgconn) + enc = conn_encoding(conn) + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + return b".".join(escs) + + +class Literal(Composable): + """ + A `Composable` representing an SQL value to include in a query. + + Usually you will want to include placeholders in the query and pass values + as `~cursor.execute()` arguments. If however you really really need to + include a literal value in the query you can use this object. + + The string returned by `!as_string()` follows the normal :ref:`adaptation + rules <types-adaptation>` for Python objects. + + Example:: + + >>> s1 = sql.Literal("fo'o") + >>> s2 = sql.Literal(42) + >>> s3 = sql.Literal(date(2000, 1, 1)) + >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) + 'fo''o', 42, '2000-01-01'::date + + """ + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + tx = Transformer.from_context(context) + return tx.as_literal(self._obj) + + +class Placeholder(Composable): + """A `Composable` representing a placeholder for query parameters. + + If the name is specified, generate a named placeholder (e.g. ``%(name)s``, + ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``, + ``%b``). + + The object is useful to generate SQL queries with a variable number of + arguments. + + Examples:: + + >>> names = ['foo', 'bar', 'baz'] + + >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) + >>> print(q1.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s) + + >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) + >>> print(q2.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s) + + """ + + def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO): + super().__init__(name) + if not isinstance(name, str): + raise TypeError(f"expected string as name, got {name!r}") + + if ")" in name: + raise ValueError(f"invalid name: {name!r}") + + if type(format) is str: + format = PyFormat(format) + if not isinstance(format, PyFormat): + raise TypeError( + f"expected PyFormat as format, got {type(format).__name__!r}" + ) + + self._format: PyFormat = format + + def __repr__(self) -> str: + parts = [] + if self._obj: + parts.append(repr(self._obj)) + if self._format is not PyFormat.AUTO: + parts.append(f"format={self._format.name}") + + return f"{self.__class__.__name__}({', '.join(parts)})" + + def as_string(self, context: Optional[AdaptContext]) -> str: + code = self._format.value + return f"%({self._obj}){code}" if self._obj else f"%{code}" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + enc = conn_encoding(conn) + return self.as_string(context).encode(enc) + + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py new file mode 100644 index 0000000..e13486e --- /dev/null +++ b/psycopg/psycopg/transaction.py @@ -0,0 +1,290 @@ +""" +Transaction context managers returned by Connection.transaction() +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from types import TracebackType +from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, PQGen + +if TYPE_CHECKING: + from typing import Any + from .connection import Connection + from .connection_async import AsyncConnection + +IDLE = pq.TransactionStatus.IDLE + +OK = pq.ConnStatus.OK + +logger = logging.getLogger(__name__) + + +class Rollback(Exception): + """ + Exit the current `Transaction` context immediately and rollback any changes + made within this context. + + If a transaction context is specified in the constructor, rollback + enclosing transactions contexts up to and including the one specified. + """ + + __module__ = "psycopg" + + def __init__( + self, + transaction: Union["Transaction", "AsyncTransaction", None] = None, + ): + self.transaction = transaction + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}({self.transaction!r})" + + +class OutOfOrderTransactionNesting(e.ProgrammingError): + """Out-of-order transaction nesting detected""" + + +class BaseTransaction(Generic[ConnectionType]): + def __init__( + self, + connection: ConnectionType, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ): + self._conn = connection + self.pgconn = self._conn.pgconn + self._savepoint_name = savepoint_name or "" + self.force_rollback = force_rollback + self._entered = self._exited = False + self._outer_transaction = False + self._stack_index = -1 + + @property + def savepoint_name(self) -> Optional[str]: + """ + The name of the savepoint; `!None` if handling the main transaction. + """ + # Yes, it may change on __enter__. No, I don't care, because the + # un-entered state is outside the public interface. + return self._savepoint_name + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + if not self._entered: + status = "inactive" + elif not self._exited: + status = "active" + else: + status = "terminated" + + sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" + return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" + + def _enter_gen(self) -> PQGen[None]: + if self._entered: + raise TypeError("transaction blocks can be used only once") + self._entered = True + + self._push_savepoint() + for command in self._get_enter_commands(): + yield from self._conn._exec_command(command) + + def _exit_gen( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> PQGen[bool]: + if not exc_val and not self.force_rollback: + yield from self._commit_gen() + return False + else: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + return (yield from self._rollback_gen(exc_val)) + except OutOfOrderTransactionNesting: + # Clobber an exception happened in the block with the exception + # caused by out-of-order transaction detected, so make the + # behaviour consistent with _commit_gen and to make sure the + # user fixes this condition, which is unrelated from + # operational error that might arise in the block. + raise + except Exception as exc2: + logger.warning("error ignored in rollback of %s: %s", self, exc2) + return False + + def _commit_gen(self) -> PQGen[None]: + ex = self._pop_savepoint("commit") + self._exited = True + if ex: + raise ex + + for command in self._get_commit_commands(): + yield from self._conn._exec_command(command) + + def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: + if isinstance(exc_val, Rollback): + logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True) + + ex = self._pop_savepoint("rollback") + self._exited = True + if ex: + raise ex + + for command in self._get_rollback_commands(): + yield from self._conn._exec_command(command) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False + + def _get_enter_commands(self) -> Iterator[bytes]: + if self._outer_transaction: + yield self._conn._get_tx_start_command() + + if self._savepoint_name: + yield ( + sql.SQL("SAVEPOINT {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + def _get_commit_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("RELEASE {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"COMMIT" + + def _get_rollback_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("ROLLBACK TO {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + yield ( + sql.SQL("RELEASE {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"ROLLBACK" + + # Also clear the prepared statements cache. + if self._conn._prepared.clear(): + yield from self._conn._prepared.get_maintenance_commands() + + def _push_savepoint(self) -> None: + """ + Push the transaction on the connection transactions stack. + + Also set the internal state of the object and verify consistency. + """ + self._outer_transaction = self.pgconn.transaction_status == IDLE + if self._outer_transaction: + # outer transaction: if no name it's only a begin, else + # there will be an additional savepoint + assert not self._conn._num_transactions + else: + # inner transaction: it always has a name + if not self._savepoint_name: + self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}" + + self._stack_index = self._conn._num_transactions + self._conn._num_transactions += 1 + + def _pop_savepoint(self, action: str) -> Optional[Exception]: + """ + Pop the transaction from the connection transactions stack. + + Also verify the state consistency. + """ + self._conn._num_transactions -= 1 + if self._conn._num_transactions == self._stack_index: + return None + + return OutOfOrderTransactionNesting( + f"transaction {action} at the wrong nesting level: {self}" + ) + + +class Transaction(BaseTransaction["Connection[Any]"]): + """ + Returned by `Connection.transaction()` to handle a transaction block. + """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="Transaction") + + @property + def connection(self) -> "Connection[Any]": + """The connection the object is managing.""" + return self._conn + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + with self._conn.lock: + return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False + + +class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): + """ + Returned by `AsyncConnection.transaction()` to handle a transaction block. + """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="AsyncTransaction") + + @property + def connection(self) -> "AsyncConnection[Any]": + return self._conn + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + async with self._conn.lock: + return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py new file mode 100644 index 0000000..bdddf05 --- /dev/null +++ b/psycopg/psycopg/types/__init__.py @@ -0,0 +1,11 @@ +""" +psycopg types package +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. import _typeinfo + +# Exposed here +TypeInfo = _typeinfo.TypeInfo +TypesRegistry = _typeinfo.TypesRegistry diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py new file mode 100644 index 0000000..e35c5e7 --- /dev/null +++ b/psycopg/psycopg/types/array.py @@ -0,0 +1,464 @@ +""" +Adapters for arrays +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type + +from .. import pq +from .. import errors as e +from .. import postgres +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._compat import cache, prod +from .._struct import pack_len, unpack_len +from .._cmodule import _psycopg +from ..postgres import TEXT_OID, INVALID_OID +from .._typeinfo import TypeInfo + +_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid +_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack) +_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from) +_struct_dim = struct.Struct("!II") # dim, lower bound +_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack) +_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from) + +TEXT_ARRAY_OID = postgres.types["text"].array_oid + +PY_TEXT = PyFormat.TEXT +PQ_BINARY = pq.Format.BINARY + + +class BaseListDumper(RecursiveDumper): + element_oid = 0 + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + if cls is NoneType: + cls = list + + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + if self.element_oid and context: + sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format) + self.sub_dumper = sdclass(NoneType, context) + + def _find_list_element(self, L: List[Any], format: PyFormat) -> Any: + """ + Find the first non-null element of an eventually nested list + """ + items = list(self._flatiter(L, set())) + types = {type(item): item for item in items} + if not types: + return None + + if len(types) == 1: + t, v = types.popitem() + else: + # More than one type in the list. It might be still good, as long + # as they dump with the same oid (e.g. IPv4Network, IPv6Network). + dumpers = [self._tx.get_dumper(item, format) for item in types.values()] + oids = set(d.oid for d in dumpers) + if len(oids) == 1: + t, v = types.popitem() + else: + raise e.DataError( + "cannot dump lists of mixed types;" + f" got: {', '.join(sorted(t.__name__ for t in types))}" + ) + + # Checking for precise type. If the type is a subclass (e.g. Int4) + # we assume the user knows what type they are passing. + if t is not int: + return v + + # If we got an int, let's see what is the biggest one in order to + # choose the smallest OID and allow Postgres to do the right cast. + imax: int = max(items) + imin: int = min(items) + if imin >= 0: + return imax + else: + return max(imax, -imin - 1) + + def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: + if id(L) in seen: + raise e.DataError("cannot dump a recursive list") + + seen.add(id(L)) + + for item in L: + if type(item) is list: + yield from self._flatiter(item, seen) + elif item is not None: + yield item + + return None + + def _get_base_type_info(self, base_oid: int) -> TypeInfo: + """ + Return info about the base type. + + Return text info as fallback. + """ + if base_oid: + info = self._tx.adapters.types.get(base_oid) + if info: + return info + + return self._tx.adapters.types["text"] + + +class ListDumper(BaseListDumper): + + delimiter = b"," + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return self.cls + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + # Empty lists can only be dumped as text if the type is unknown. + return self + + sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + + # We consider an array of unknowns as unknown, so we can dump empty + # lists or lists containing only None elements. + if sd.oid != INVALID_OID: + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + dumper.delimiter = info.delimiter.encode() + else: + dumper.oid = INVALID_OID + + return dumper + + # Double quotes and backslashes embedded in element values will be + # backslash-escaped. + _re_esc = re.compile(rb'(["\\])') + + def dump(self, obj: List[Any]) -> bytes: + tokens: List[Buffer] = [] + needs_quotes = _get_needs_quotes_regexp(self.delimiter).search + + def dump_list(obj: List[Any]) -> None: + if not obj: + tokens.append(b"{}") + return + + tokens.append(b"{") + for item in obj: + if isinstance(item, list): + dump_list(item) + elif item is not None: + ad = self._dump_item(item) + if needs_quotes(ad): + if not isinstance(ad, bytes): + ad = bytes(ad) + ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"' + tokens.append(ad) + else: + tokens.append(b"NULL") + + tokens.append(self.delimiter) + + tokens[-1] = b"}" + + dump_list(obj) + + return b"".join(tokens) + + def _dump_item(self, item: Any) -> Buffer: + if self.sub_dumper: + return self.sub_dumper.dump(item) + else: + return self._tx.get_dumper(item, PY_TEXT).dump(item) + + +@cache +def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]: + """Return a regexp to recognise when a value needs quotes + + from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO + + The array output routine will put double quotes around element values if + they are empty strings, contain curly braces, delimiter characters, + double quotes, backslashes, or white space, or match the word NULL. + """ + return re.compile( + rb"""(?xi) + ^$ # the empty string + | ["{}%s\\\s] # or a char to escape + | ^null$ # or the word NULL + """ + % delimiter + ) + + +class ListBinaryDumper(BaseListDumper): + + format = pq.Format.BINARY + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return (self.cls,) + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + return ListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + + return dumper + + def dump(self, obj: List[Any]) -> bytes: + # Postgres won't take unknown for element oid: fall back on text + sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID + + if not obj: + return _pack_head(0, 0, sub_oid) + + data: List[Buffer] = [b"", b""] # placeholders to avoid a resize + dims: List[int] = [] + hasnull = 0 + + def calc_dims(L: List[Any]) -> None: + if isinstance(L, self.cls): + if not L: + raise e.DataError("lists cannot contain empty lists") + dims.append(len(L)) + calc_dims(L[0]) + + calc_dims(obj) + + def dump_list(L: List[Any], dim: int) -> None: + nonlocal hasnull + if len(L) != dims[dim]: + raise e.DataError("nested lists have inconsistent lengths") + + if dim == len(dims) - 1: + for item in L: + if item is not None: + # If we get here, the sub_dumper must have been set + ad = self.sub_dumper.dump(item) # type: ignore[union-attr] + data.append(pack_len(len(ad))) + data.append(ad) + else: + hasnull = 1 + data.append(b"\xff\xff\xff\xff") + else: + for item in L: + if not isinstance(item, self.cls): + raise e.DataError("nested lists have inconsistent depths") + dump_list(item, dim + 1) # type: ignore + + dump_list(obj, 0) + + data[0] = _pack_head(len(dims), hasnull, sub_oid) + data[1] = b"".join(_pack_dim(dim, 1) for dim in dims) + return b"".join(data) + + +class ArrayLoader(RecursiveLoader): + + delimiter = b"," + base_oid: int + + def load(self, data: Buffer) -> List[Any]: + loader = self._tx.get_loader(self.base_oid, self.format) + return _load_text(data, loader, self.delimiter) + + +class ArrayBinaryLoader(RecursiveLoader): + + format = pq.Format.BINARY + + def load(self, data: Buffer) -> List[Any]: + return _load_binary(data, self._tx) + + +def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + if not info.array_oid: + raise ValueError(f"the type info {info} doesn't describe an array") + + base: Type[Any] + adapters = context.adapters if context else postgres.adapters + + base = getattr(_psycopg, "ArrayLoader", ArrayLoader) + name = f"{info.name.title()}{base.__name__}" + attribs = { + "base_oid": info.oid, + "delimiter": info.delimiter.encode(), + } + loader = type(name, (base,), attribs) + adapters.register_loader(info.array_oid, loader) + + loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader) + adapters.register_loader(info.array_oid, loader) + + base = ListDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + "delimiter": info.delimiter.encode(), + } + dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + base = ListBinaryDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + } + dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + +def register_default_adapters(context: AdaptContext) -> None: + # The text dumper is more flexible as it can handle lists of mixed type, + # so register it later. + context.adapters.register_dumper(list, ListBinaryDumper) + context.adapters.register_dumper(list, ListDumper) + + +def register_all_arrays(context: AdaptContext) -> None: + """ + Associate the array oid of all the types in Loader.globals. + + This function is designed to be called once at import time, after having + registered all the base loaders. + """ + for t in context.adapters.types: + if t.array_oid: + t.register(context) + + +def _load_text( + data: Buffer, + loader: Loader, + delimiter: bytes = b",", + __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"), +) -> List[Any]: + rv = None + stack: List[Any] = [] + a: List[Any] = [] + rv = a + load = loader.load + + # Remove the dimensions information prefix (``[...]=``) + if data and data[0] == b"["[0]: + if isinstance(data, memoryview): + data = bytes(data) + idx = data.find(b"=") + if idx == -1: + raise e.DataError("malformed array: no '=' after dimension information") + data = data[idx + 1 :] + + re_parse = _get_array_parse_regexp(delimiter) + for m in re_parse.finditer(data): + t = m.group(1) + if t == b"{": + if stack: + stack[-1].append(a) + stack.append(a) + a = [] + + elif t == b"}": + if not stack: + raise e.DataError("malformed array: unexpected '}'") + rv = stack.pop() + + else: + if not stack: + wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else "" + raise e.DataError(f"malformed array: unexpected '{wat}'") + if t == b"NULL": + v = None + else: + if t.startswith(b'"'): + t = __re_unescape.sub(rb"\1", t[1:-1]) + v = load(t) + + stack[-1].append(v) + + assert rv is not None + return rv + + +@cache +def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: + """ + Return a regexp to tokenize an array representation into item and brackets + """ + return re.compile( + rb"""(?xi) + ( [{}] # open or closed bracket + | " (?: [^"\\] | \\. )* " # or a quoted string + | [^"{}%s\\]+ # or an unquoted non-empty string + ) ,? + """ + % delimiter + ) + + +def _load_binary(data: Buffer, tx: Transformer) -> List[Any]: + ndims, hasnull, oid = _unpack_head(data) + load = tx.get_loader(oid, PQ_BINARY).load + + if not ndims: + return [] + + p = 12 + 8 * ndims + dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)] + nelems = prod(dims) + + out: List[Any] = [None] * nelems + for i in range(nelems): + size = unpack_len(data, p)[0] + p += 4 + if size == -1: + continue + out[i] = load(data[p : p + size]) + p += size + + # fon ndims > 1 we have to aggregate the array into sub-arrays + for dim in dims[-1:0:-1]: + out = [out[i : i + dim] for i in range(0, len(out), dim)] + + return out diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py new file mode 100644 index 0000000..db7e181 --- /dev/null +++ b/psycopg/psycopg/types/bool.py @@ -0,0 +1,51 @@ +""" +Adapters for booleans. +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + + +class BoolDumper(Dumper): + + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"t" if obj else b"f" + + def quote(self, obj: bool) -> bytes: + return b"true" if obj else b"false" + + +class BoolBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"\x01" if obj else b"\x00" + + +class BoolLoader(Loader): + def load(self, data: Buffer) -> bool: + return data == b"t" + + +class BoolBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> bool: + return data != b"\x00" + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(bool, BoolDumper) + adapters.register_dumper(bool, BoolBinaryDumper) + adapters.register_loader("bool", BoolLoader) + adapters.register_loader("bool", BoolBinaryLoader) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py new file mode 100644 index 0000000..1c609c3 --- /dev/null +++ b/psycopg/psycopg/types/composite.py @@ -0,0 +1,290 @@ +""" +Support for composite types adaptation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from collections import namedtuple +from typing import Any, Callable, cast, Iterator, List, Optional +from typing import Sequence, Tuple, Type + +from .. import pq +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader +from .._struct import pack_len, unpack_len +from ..postgres import TEXT_OID +from .._typeinfo import CompositeInfo as CompositeInfo # exported here +from .._encodings import _as_python_identifier + +_struct_oidlen = struct.Struct("!Ii") +_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) +_unpack_oidlen = cast( + Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from +) + + +class SequenceDumper(RecursiveDumper): + def _dump_sequence( + self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes + ) -> bytes: + if not obj: + return start + end + + parts: List[Buffer] = [start] + + for item in obj: + if item is None: + parts.append(sep) + continue + + dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + ad = dumper.dump(item) + if not ad: + ad = b'""' + elif self._re_needs_quotes.search(ad): + ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"' + + parts.append(ad) + parts.append(sep) + + parts[-1] = end + + return b"".join(parts) + + _re_needs_quotes = re.compile(rb'[",\\\s()]') + _re_esc = re.compile(rb"([\\\"])") + + +class TupleDumper(SequenceDumper): + + # Should be this, but it doesn't work + # oid = postgres_types["record"].oid + + def dump(self, obj: Tuple[Any, ...]) -> bytes: + return self._dump_sequence(obj, b"(", b")", b",") + + +class TupleBinaryDumper(RecursiveDumper): + + format = pq.Format.BINARY + + # Subclasses must set an info + info: CompositeInfo + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + nfields = len(self.info.field_types) + self._tx.set_dumper_types(self.info.field_types, self.format) + self._formats = (PyFormat.from_pq(self.format),) * nfields + + def dump(self, obj: Tuple[Any, ...]) -> bytearray: + out = bytearray(pack_len(len(obj))) + adapted = self._tx.dump_sequence(obj, self._formats) + for i in range(len(obj)): + b = adapted[i] + oid = self.info.field_types[i] + if b is not None: + out += _pack_oidlen(oid, len(b)) + out += b + else: + out += _pack_oidlen(oid, -1) + + return out + + +class BaseCompositeLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]: + """ + Split a non-empty representation of a composite type into components. + + Terminators shouldn't be used in `!data` (so that both record and range + representations can be parsed). + """ + for m in self._re_tokenize.finditer(data): + if m.group(1): + yield None + elif m.group(2) is not None: + yield self._re_undouble.sub(rb"\1", m.group(2)) + else: + yield m.group(3) + + # If the final group ended in `,` there is a final NULL in the record + # that the regexp couldn't parse. + if m and m.group().endswith(b","): + yield None + + _re_tokenize = re.compile( + rb"""(?x) + (,) # an empty token, representing NULL + | " ((?: [^"] | "")*) " ,? # or a quoted string + | ([^",)]+) ,? # or an unquoted string + """ + ) + + _re_undouble = re.compile(rb'(["\\])\1') + + +class RecordLoader(BaseCompositeLoader): + def load(self, data: Buffer) -> Tuple[Any, ...]: + if data == b"()": + return () + + cast = self._tx.get_loader(TEXT_OID, self.format).load + return tuple( + cast(token) if token is not None else None + for token in self._parse_record(data[1:-1]) + ) + + +class RecordBinaryLoader(Loader): + format = pq.Format.BINARY + _types_set = False + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def load(self, data: Buffer) -> Tuple[Any, ...]: + if not self._types_set: + self._config_types(data) + self._types_set = True + + return self._tx.load_sequence( + tuple( + data[offset : offset + length] if length != -1 else None + for _, offset, length in self._walk_record(data) + ) + ) + + def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]: + """ + Yield a sequence of (oid, offset, length) for the content of the record + """ + nfields = unpack_len(data, 0)[0] + i = 4 + for _ in range(nfields): + oid, length = _unpack_oidlen(data, i) + yield oid, i + 8, length + i += (8 + length) if length > 0 else 8 + + def _config_types(self, data: Buffer) -> None: + oids = [r[0] for r in self._walk_record(data)] + self._tx.set_loader_types(oids, self.format) + + +class CompositeLoader(RecordLoader): + + factory: Callable[..., Any] + fields_types: List[int] + _types_set = False + + def load(self, data: Buffer) -> Any: + if not self._types_set: + self._config_types(data) + self._types_set = True + + if data == b"()": + return type(self).factory() + + return type(self).factory( + *self._tx.load_sequence(tuple(self._parse_record(data[1:-1]))) + ) + + def _config_types(self, data: Buffer) -> None: + self._tx.set_loader_types(self.fields_types, self.format) + + +class CompositeBinaryLoader(RecordBinaryLoader): + + format = pq.Format.BINARY + factory: Callable[..., Any] + + def load(self, data: Buffer) -> Any: + r = super().load(data) + return type(self).factory(*r) + + +def register_composite( + info: CompositeInfo, + context: Optional[AdaptContext] = None, + factory: Optional[Callable[..., Any]] = None, +) -> None: + """Register the adapters to load and dump a composite type. + + :param info: The object with the information about the composite to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + :param factory: Callable to convert the sequence of attributes read from + the composite into a Python object. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested composite available?") + + # Register arrays and type info + info.register(context) + + if not factory: + factory = namedtuple( # type: ignore + _as_python_identifier(info.name), + [_as_python_identifier(n) for n in info.field_names], + ) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[BaseCompositeLoader] = type( + f"{info.name.title()}Loader", + (CompositeLoader,), + { + "factory": factory, + "fields_types": info.field_types, + }, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + loader = type( + f"{info.name.title()}BinaryLoader", + (CompositeBinaryLoader,), + {"factory": factory}, + ) + adapters.register_loader(info.oid, loader) + + # If the factory is a type, create and register dumpers for it + if isinstance(factory, type): + dumper = type( + f"{info.name.title()}BinaryDumper", + (TupleBinaryDumper,), + {"oid": info.oid, "info": info}, + ) + adapters.register_dumper(factory, dumper) + + # Default to the text dumper because it is more flexible + dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid}) + adapters.register_dumper(factory, dumper) + + info.python_type = factory + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(tuple, TupleDumper) + adapters.register_loader("record", RecordLoader) + adapters.register_loader("record", RecordBinaryLoader) diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py new file mode 100644 index 0000000..f0dfe83 --- /dev/null +++ b/psycopg/psycopg/types/datetime.py @@ -0,0 +1,754 @@ +""" +Adapters for date/time types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from datetime import date, datetime, time, timedelta, timezone +from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING + +from .. import postgres +from ..pq import Format +from .._tz import get_tzinfo +from ..abc import AdaptContext, DumperKey +from ..adapt import Buffer, Dumper, Loader, PyFormat +from ..errors import InterfaceError, DataError +from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8 + +if TYPE_CHECKING: + from ..connection import BaseConnection + +_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset +_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack) +_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack) + +_struct_interval = struct.Struct("!qii") # microseconds, days, months +_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack) +_unpack_interval = cast( + Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack +) + +utc = timezone.utc +_pg_date_epoch_days = date(2000, 1, 1).toordinal() +_pg_datetime_epoch = datetime(2000, 1, 1) +_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc) +_py_date_min_days = date.min.toordinal() + + +class DateDumper(Dumper): + + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DateBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + days = obj.toordinal() - _pg_date_epoch_days + return pack_int4(days) + + +class _BaseTimeDumper(Dumper): + def get_key(self, obj: time, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade to a dumper for timetz (the + # Frankenstein of the data types). + if not obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseTimeTextDumper(_BaseTimeDumper): + def dump(self, obj: time) -> bytes: + return str(obj).encode() + + +class TimeDumper(_BaseTimeTextDumper): + + oid = postgres.types["time"].oid + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzDumper(self.cls) + + +class TimeTzDumper(_BaseTimeTextDumper): + + oid = postgres.types["timetz"].oid + + +class TimeBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["time"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + return pack_int8(us) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzBinaryDumper(self.cls) + + +class TimeTzBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["timetz"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + off = obj.utcoffset() + assert off is not None + return _pack_timetz(us, -int(off.total_seconds())) + + +class _BaseDatetimeDumper(Dumper): + def get_key(self, obj: datetime, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade (downgrade, actually) to a + # dumper for naive timestamp. + if obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseDatetimeTextDumper(_BaseDatetimeDumper): + def dump(self, obj: datetime) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DatetimeDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamptz"].oid + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzDumper(self.cls) + + +class DatetimeNoTzDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamp"].oid + + +class DatetimeBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamptz"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetimetz_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzBinaryDumper(self.cls) + + +class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamp"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetime_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + +class TimedeltaDumper(Dumper): + + oid = postgres.types["interval"].oid + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + if self.connection: + if ( + self.connection.pgconn.parameter_status(b"IntervalStyle") + == b"sql_standard" + ): + setattr(self, "dump", self._dump_sql) + + def dump(self, obj: timedelta) -> bytes: + # The comma is parsed ok by PostgreSQL but it's not documented + # and it seems brittle to rely on it. CRDB doesn't consume it well. + return str(obj).encode().replace(b",", b"") + + def _dump_sql(self, obj: timedelta) -> bytes: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + return b"%+d day %+d second %+d microsecond" % ( + obj.days, + obj.seconds, + obj.microseconds, + ) + + +class TimedeltaBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["interval"].oid + + def dump(self, obj: timedelta) -> bytes: + micros = 1_000_000 * obj.seconds + obj.microseconds + return _pack_interval(micros, obj.days, 0) + + +class DateLoader(Loader): + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> date: + if self._order == self._ORDER_YMD: + ye = data[:4] + mo = data[5:7] + da = data[8:] + elif self._order == self._ORDER_DMY: + da = data[:2] + mo = data[3:5] + ye = data[6:] + else: + mo = data[:2] + da = data[3:5] + ye = data[6:] + + try: + return date(int(ye), int(mo), int(da)) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + if s == "infinity" or (s and len(s.split()[0]) > 10): + raise DataError(f"date too large (after year 10K): {s!r}") from None + elif s == "-infinity" or "BC" in s: + raise DataError(f"date too small (before year 1): {s!r}") from None + else: + raise DataError(f"can't parse date {s!r}: {ex}") from None + + +class DateBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> date: + days = unpack_int4(data)[0] + _pg_date_epoch_days + try: + return date.fromordinal(days) + except (ValueError, OverflowError): + if days < _py_date_min_days: + raise DataError("date too small (before year 1)") from None + else: + raise DataError("date too large (after year 10K)") from None + + +class TimeLoader(Loader): + + _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?") + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}") + + ho, mi, se, fr = m.groups() + + # Pad the fraction of second to get micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return time(int(ho), int(mi), int(se), us) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}: {e}") from None + + +class TimeBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val = unpack_int8(data)[0] + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + try: + return time(h, m, s, us) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimetzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone + $ + """ + ) + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}") + + ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone + off = 60 * 60 * int(oh) + if om: + off += 60 * int(om) + if os: + off += int(os) + tz = timezone(timedelta(0, off if sgn == b"+" else -off)) + + try: + return time(int(ho), int(mi), int(se), us, tz) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}: {e}") from None + + +class TimetzBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val, off = _unpack_timetz(data) + + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + + try: + return time(h, m, s, us, timezone(timedelta(seconds=-off))) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimestampLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + $ + """ + ) + _re_format_pg = re.compile( + rb"""(?ix) + ^ + [a-z]+ [^a-z0-9] # DoW, separator + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + [^a-z0-9] (\d+) # Year + $ + """ + ) + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + _ORDER_PGDM = 3 + _ORDER_PGMD = 4 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S"): # SQL + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + elif ds.startswith(b"P"): # Postgres + self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD + self._re_format = self._re_format_pg + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + if self._order == self._ORDER_YMD: + ye, mo, da, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_DMY: + da, mo, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_MDY: + mo, da, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + else: + if self._order == self._ORDER_PGDM: + da, mo, ho, mi, se, fr, ye = m.groups() + else: + mo, da, ho, mi, se, fr, ye = m.groups() + try: + imo = _month_abbr[mo] + except KeyError: + s = mo.decode("utf8", "replace") + raise DataError(f"can't parse month: {s!r}") from None + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us) + except ValueError as ex: + raise _get_timestamp_load_error(self.connection, data, ex) from None + + +class TimestampBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> datetime: + micros = unpack_int8(data)[0] + try: + return _pg_datetime_epoch + timedelta(microseconds=micros) + except OverflowError: + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class TimestamptzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone + $ + """ + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + ds = _get_datestyle(self.connection) + if not ds.startswith(b"I"): # not ISO + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone offset + soff = 60 * 60 * int(oh) + if om: + soff += 60 * int(om) + if os: + soff += int(os) + tzoff = timedelta(0, soff if sgn == b"+" else -soff) + + # The return value is a datetime with the timezone of the connection + # (in order to be consistent with the binary loader, which is the only + # thing it can return). So create a temporary datetime object, in utc, + # shift it by the offset parsed from the timestamp, and then move it to + # the connection timezone. + dt = None + ex: Exception + try: + dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc) + return (dt - tzoff).astimezone(self._timezone) + except OverflowError as e: + # If we have created the temporary 'dt' it means that we have a + # datetime close to max, the shift pushed it past max, overflowing. + # In this case return the datetime in a fixed offset timezone. + if dt is not None: + return dt.replace(tzinfo=timezone(tzoff)) + else: + ex = e + except ValueError as e: + ex = e + + raise _get_timestamp_load_error(self.connection, data, ex) from None + + def _load_notimpl(self, data: Buffer) -> datetime: + s = bytes(data).decode("utf8", "replace") + ds = _get_datestyle(self.connection).decode("ascii") + raise NotImplementedError( + f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" + ) + + +class TimestamptzBinaryLoader(Loader): + + format = Format.BINARY + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + def load(self, data: Buffer) -> datetime: + micros = unpack_int8(data)[0] + try: + ts = _pg_datetimetz_epoch + timedelta(microseconds=micros) + return ts.astimezone(self._timezone) + except OverflowError: + # If we were asked about a timestamp which would overflow in UTC, + # but not in the desired timezone (e.g. datetime.max at Chicago + # timezone) we can still save the day by shifting the value by the + # timezone offset and then replacing the timezone. + if self._timezone: + utcoff = self._timezone.utcoffset( + datetime.min if micros < 0 else datetime.max + ) + if utcoff: + usoff = 1_000_000 * int(utcoff.total_seconds()) + try: + ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff) + except OverflowError: + pass # will raise downstream + else: + return ts.replace(tzinfo=self._timezone) + + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class IntervalLoader(Loader): + + _re_interval = re.compile( + rb""" + (?: ([-+]?\d+) \s+ years? \s* )? # Years + (?: ([-+]?\d+) \s+ mons? \s* )? # Months + (?: ([-+]?\d+) \s+ days? \s* )? # Days + (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time + )? + """, + re.VERBOSE, + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if self.connection: + ints = self.connection.pgconn.parameter_status(b"IntervalStyle") + if ints != b"postgres": + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> timedelta: + m = self._re_interval.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}") + + ye, mo, da, sgn, ho, mi, se = m.groups() + days = 0 + seconds = 0.0 + + if ye: + days += 365 * int(ye) + if mo: + days += 30 * int(mo) + if da: + days += int(da) + + if ho: + seconds = 3600 * int(ho) + 60 * int(mi) + float(se) + if sgn == b"-": + seconds = -seconds + + try: + return timedelta(days=days, seconds=seconds) + except OverflowError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}: {e}") from None + + def _load_notimpl(self, data: Buffer) -> timedelta: + s = bytes(data).decode("utf8", "replace") + ints = ( + self.connection + and self.connection.pgconn.parameter_status(b"IntervalStyle") + or b"unknown" + ).decode("utf8", "replace") + raise NotImplementedError( + f"can't parse interval with IntervalStyle {ints}: {s!r}" + ) + + +class IntervalBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> timedelta: + micros, days, months = _unpack_interval(data) + if months > 0: + years, months = divmod(months, 12) + days = days + 30 * months + 365 * years + elif months < 0: + years, months = divmod(-months, 12) + days = days - 30 * months - 365 * years + + try: + return timedelta(days=days, microseconds=micros) + except OverflowError as e: + raise DataError(f"can't parse interval: {e}") from None + + +def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes: + if conn: + ds = conn.pgconn.parameter_status(b"DateStyle") + if ds: + return ds + + return b"ISO, DMY" + + +def _get_timestamp_load_error( + conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None +) -> Exception: + s = bytes(data).decode("utf8", "replace") + + def is_overflow(s: str) -> bool: + if not s: + return False + + ds = _get_datestyle(conn) + if not ds.startswith(b"P"): # Postgres + return len(s.split()[0]) > 10 # date is first token + else: + return len(s.split()[-1]) > 4 # year is last token + + if s == "-infinity" or s.endswith("BC"): + return DataError("timestamp too small (before year 1): {s!r}") + elif s == "infinity" or is_overflow(s): + return DataError(f"timestamp too large (after year 10K): {s!r}") + else: + return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}") + + +_month_abbr = { + n: i + for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1) +} + +# Pad to get microseconds from a fraction of seconds +_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1] + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("datetime.date", DateDumper) + adapters.register_dumper("datetime.date", DateBinaryDumper) + + # first register dumpers for 'timetz' oid, then the proper ones on time type. + adapters.register_dumper("datetime.time", TimeTzDumper) + adapters.register_dumper("datetime.time", TimeTzBinaryDumper) + adapters.register_dumper("datetime.time", TimeDumper) + adapters.register_dumper("datetime.time", TimeBinaryDumper) + + # first register dumpers for 'timestamp' oid, then the proper ones + # on the datetime type. + adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper) + adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper) + adapters.register_dumper("datetime.datetime", DatetimeDumper) + adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper) + + adapters.register_dumper("datetime.timedelta", TimedeltaDumper) + adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper) + + adapters.register_loader("date", DateLoader) + adapters.register_loader("date", DateBinaryLoader) + adapters.register_loader("time", TimeLoader) + adapters.register_loader("time", TimeBinaryLoader) + adapters.register_loader("timetz", TimetzLoader) + adapters.register_loader("timetz", TimetzBinaryLoader) + adapters.register_loader("timestamp", TimestampLoader) + adapters.register_loader("timestamp", TimestampBinaryLoader) + adapters.register_loader("timestamptz", TimestamptzLoader) + adapters.register_loader("timestamptz", TimestamptzBinaryLoader) + adapters.register_loader("interval", IntervalLoader) + adapters.register_loader("interval", IntervalBinaryLoader) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py new file mode 100644 index 0000000..d3c7387 --- /dev/null +++ b/psycopg/psycopg/types/enum.py @@ -0,0 +1,177 @@ +""" +Adapters for the enum type. +""" +from enum import Enum +from typing import Any, Dict, Generic, Optional, Mapping, Sequence +from typing import Tuple, Type, TypeVar, Union, cast +from typing_extensions import TypeAlias + +from .. import postgres +from .. import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from .._encodings import conn_encoding +from .._typeinfo import EnumInfo as EnumInfo # exported here + +E = TypeVar("E", bound=Enum) + +EnumDumpMap: TypeAlias = Dict[E, bytes] +EnumLoadMap: TypeAlias = Dict[bytes, E] +EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None] + + +class _BaseEnumLoader(Loader, Generic[E]): + """ + Loader for a specific Enum class + """ + + enum: Type[E] + _load_map: EnumLoadMap[E] + + def load(self, data: Buffer) -> E: + if not isinstance(data, bytes): + data = bytes(data) + + try: + return self._load_map[data] + except KeyError: + enc = conn_encoding(self.connection) + label = data.decode(enc, "replace") + raise e.DataError( + f"bad member for enum {self.enum.__qualname__}: {label!r}" + ) + + +class _BaseEnumDumper(Dumper, Generic[E]): + """ + Dumper for a specific Enum class + """ + + enum: Type[E] + _dump_map: EnumDumpMap[E] + + def dump(self, value: E) -> Buffer: + return self._dump_map[value] + + +class EnumDumper(Dumper): + """ + Dumper for a generic Enum class + """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._encoding = conn_encoding(self.connection) + + def dump(self, value: E) -> Buffer: + return value.name.encode(self._encoding) + + +class EnumBinaryDumper(EnumDumper): + format = Format.BINARY + + +def register_enum( + info: EnumInfo, + context: Optional[AdaptContext] = None, + enum: Optional[Type[E]] = None, + *, + mapping: EnumMapping[E] = None, +) -> None: + """Register the adapters to load and dump a enum type. + + :param info: The object with the information about the enum to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + :param enum: Python enum type matching to the PostgreSQL one. If `!None`, + a new enum will be generated and exposed as `EnumInfo.enum`. + :param mapping: Override the mapping between `!enum` members and `!info` + labels. + """ + + if not info: + raise TypeError("no info passed. Is the requested enum available?") + + if enum is None: + enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__)) + + info.enum = enum + adapters = context.adapters if context else postgres.adapters + info.register(context) + + load_map = _make_load_map(info, enum, mapping, context) + attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map} + + name = f"{info.name.title()}Loader" + loader = type(name, (_BaseEnumLoader,), attribs) + adapters.register_loader(info.oid, loader) + + name = f"{info.name.title()}BinaryLoader" + loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY}) + adapters.register_loader(info.oid, loader) + + dump_map = _make_dump_map(info, enum, mapping, context) + attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map} + + name = f"{enum.__name__}Dumper" + dumper = type(name, (_BaseEnumDumper,), attribs) + adapters.register_dumper(info.enum, dumper) + + name = f"{enum.__name__}BinaryDumper" + dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY}) + adapters.register_dumper(info.enum, dumper) + + +def _make_load_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumLoadMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumLoadMap[E] = {} + for label in info.labels: + try: + member = enum[label] + except KeyError: + # tolerate a missing enum, assuming it won't be used. If it is we + # will get a DataError on fetch. + pass + else: + rv[label.encode(enc)] = member + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[label.encode(enc)] = member + + return rv + + +def _make_dump_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumDumpMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumDumpMap[E] = {} + for member in enum: + rv[member] = member.name.encode(enc) + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[member] = label.encode(enc) + + return rv + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, EnumBinaryDumper) + context.adapters.register_dumper(Enum, EnumDumper) diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py new file mode 100644 index 0000000..e1ab1d5 --- /dev/null +++ b/psycopg/psycopg/types/hstore.py @@ -0,0 +1,131 @@ +""" +Dict to hstore adaptation +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +from typing import Dict, List, Optional +from typing_extensions import TypeAlias + +from .. import errors as e +from .. import postgres +from ..abc import Buffer, AdaptContext +from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader +from ..postgres import TEXT_OID +from .._typeinfo import TypeInfo + +_re_escape = re.compile(r'(["\\])') +_re_unescape = re.compile(r"\\(.)") + +_re_hstore = re.compile( + r""" + # hstore key: + # a string of normal or escaped chars + "((?: [^"\\] | \\. )*)" + \s*=>\s* # hstore value + (?: + NULL # the value can be null - not caught + # or a quoted string like the key + | "((?: [^"\\] | \\. )*)" + ) + (?:\s*,\s*|$) # pairs separated by comma or end of string. +""", + re.VERBOSE, +) + + +Hstore: TypeAlias = Dict[str, Optional[str]] + + +class BaseHstoreDumper(RecursiveDumper): + def dump(self, obj: Hstore) -> Buffer: + if not obj: + return b"" + + tokens: List[str] = [] + + def add_token(s: str) -> None: + tokens.append('"') + tokens.append(_re_escape.sub(r"\\\1", s)) + tokens.append('"') + + for k, v in obj.items(): + + if not isinstance(k, str): + raise e.DataError("hstore keys can only be strings") + add_token(k) + + tokens.append("=>") + + if v is None: + tokens.append("NULL") + elif not isinstance(v, str): + raise e.DataError("hstore keys can only be strings") + else: + add_token(v) + + tokens.append(",") + + del tokens[-1] + data = "".join(tokens) + dumper = self._tx.get_dumper(data, PyFormat.TEXT) + return dumper.dump(data) + + +class HstoreLoader(RecursiveLoader): + def load(self, data: Buffer) -> Hstore: + loader = self._tx.get_loader(TEXT_OID, self.format) + s: str = loader.load(data) + + rv: Hstore = {} + start = 0 + for m in _re_hstore.finditer(s): + if m is None or m.start() != start: + raise e.DataError(f"error parsing hstore pair at char {start}") + k = _re_unescape.sub(r"\1", m.group(1)) + v = m.group(2) + if v is not None: + v = _re_unescape.sub(r"\1", v) + + rv[k] = v + start = m.end() + + if start < len(s): + raise e.DataError(f"error parsing hstore: unparsed data after char {start}") + + return rv + + +def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump hstore. + + :param info: The object with the information about the hstore type. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'hstore' extension loaded?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # Generate and register a customized text dumper + class HstoreDumper(BaseHstoreDumper): + oid = info.oid + + adapters.register_dumper(dict, HstoreDumper) + + # register the text loader on the oid + adapters.register_loader(info.oid, HstoreLoader) diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py new file mode 100644 index 0000000..a80e0e4 --- /dev/null +++ b/psycopg/psycopg/types/json.py @@ -0,0 +1,232 @@ +""" +Adapers for JSON types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import json +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +from .. import abc +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap +from ..errors import DataError + +JsonDumpsFunction = Callable[[Any], str] +JsonLoadsFunction = Callable[[Union[str, bytes]], Any] + + +def set_json_dumps( + dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON serialisation function to store JSON objects in the database. + + :param dumps: The dump function to use. + :type dumps: `!Callable[[Any], str]` + :param context: Where to use the `!dumps` function. If not specified, use it + globally. + :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default dumping JSON uses the builtin `json.dumps`. You can override + it to use a different JSON library or to use customised arguments. + + If the `Json` wrapper specified a `!dumps` function, use it in precedence + of the one set by this function. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonDumper._dumps = dumps + else: + adapters = context.adapters + + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + (Json, PyFormat.BINARY), + (Json, PyFormat.TEXT), + (Jsonb, PyFormat.BINARY), + (Jsonb, PyFormat.TEXT), + ] + dumper: Type[_JsonDumper] + for wrapper, format in grid: + base = _get_current_dumper(adapters, wrapper, format) + name = base.__name__ + if not base.__name__.startswith("Custom"): + name = f"Custom{name}" + dumper = type(name, (base,), {"_dumps": dumps}) + adapters.register_dumper(wrapper, dumper) + + +def set_json_loads( + loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON parsing function to fetch JSON objects from the database. + + :param loads: The load function to use. + :type loads: `!Callable[[bytes], Any]` + :param context: Where to use the `!loads` function. If not specified, use + it globally. + :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default loading JSON uses the builtin `json.loads`. You can override + it to use a different JSON library or to use customised arguments. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonLoader._loads = loads + else: + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + ("json", JsonLoader), + ("json", JsonBinaryLoader), + ("jsonb", JsonbLoader), + ("jsonb", JsonbBinaryLoader), + ] + loader: Type[_JsonLoader] + for tname, base in grid: + loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads}) + context.adapters.register_loader(tname, loader) + + +class _JsonWrapper: + __slots__ = ("obj", "dumps") + + def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None): + self.obj = obj + self.dumps = dumps + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} chars)" + return f"{self.__class__.__name__}({sobj})" + + +class Json(_JsonWrapper): + __slots__ = () + + +class Jsonb(_JsonWrapper): + __slots__ = () + + +class _JsonDumper(Dumper): + + # The globally used JSON dumps() function. It can be changed globally (by + # set_json_dumps) or by a subclass. + _dumps: JsonDumpsFunction = json.dumps + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self.dumps = self.__class__._dumps + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return dumps(obj.obj).encode() + + +class JsonDumper(_JsonDumper): + + oid = postgres.types["json"].oid + + +class JsonBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["json"].oid + + +class JsonbDumper(_JsonDumper): + + oid = postgres.types["jsonb"].oid + + +class JsonbBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["jsonb"].oid + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return b"\x01" + dumps(obj.obj).encode() + + +class _JsonLoader(Loader): + + # The globally used JSON loads() function. It can be changed globally (by + # set_json_loads) or by a subclass. + _loads: JsonLoadsFunction = json.loads + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self.loads = self.__class__._loads + + def load(self, data: Buffer) -> Any: + # json.loads() cannot work on memoryview. + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +class JsonLoader(_JsonLoader): + pass + + +class JsonbLoader(_JsonLoader): + pass + + +class JsonBinaryLoader(_JsonLoader): + format = Format.BINARY + + +class JsonbBinaryLoader(_JsonLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Any: + if data and data[0] != 1: + raise DataError("unknown jsonb binary format: {data[0]}") + data = data[1:] + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +def _get_current_dumper( + adapters: AdaptersMap, cls: type, format: PyFormat +) -> Type[abc.Dumper]: + try: + return adapters.get_dumper(cls, format) + except e.ProgrammingError: + return _default_dumpers[cls, format] + + +_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = { + (Json, PyFormat.BINARY): JsonBinaryDumper, + (Json, PyFormat.TEXT): JsonDumper, + (Jsonb, PyFormat.BINARY): JsonbBinaryDumper, + (Jsonb, PyFormat.TEXT): JsonDumper, +} + + +def register_default_adapters(context: abc.AdaptContext) -> None: + adapters = context.adapters + + # Currently json binary format is nothing different than text, maybe with + # an extra memcopy we can avoid. + adapters.register_dumper(Json, JsonBinaryDumper) + adapters.register_dumper(Json, JsonDumper) + adapters.register_dumper(Jsonb, JsonbBinaryDumper) + adapters.register_dumper(Jsonb, JsonbDumper) + adapters.register_loader("json", JsonLoader) + adapters.register_loader("jsonb", JsonbLoader) + adapters.register_loader("json", JsonBinaryLoader) + adapters.register_loader("jsonb", JsonbBinaryLoader) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py new file mode 100644 index 0000000..3eaa7f1 --- /dev/null +++ b/psycopg/psycopg/types/multirange.py @@ -0,0 +1,514 @@ +""" +Support for multirange types adaptation. +""" + +# Copyright (C) 2021 The Psycopg Team + +from decimal import Decimal +from typing import Any, Generic, List, Iterable +from typing import MutableSequence, Optional, Type, Union, overload +from datetime import date, datetime + +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here + +from .range import Range, T, load_range_text, load_range_binary +from .range import dump_range_text, dump_range_binary, fail_dump + + +class Multirange(MutableSequence[Range[T]]): + """Python representation for a PostgreSQL multirange type. + + :param items: Sequence of ranges to initialise the object. + """ + + def __init__(self, items: Iterable[Range[T]] = ()): + self._ranges: List[Range[T]] = list(map(self._check_type, items)) + + def _check_type(self, item: Any) -> Range[Any]: + if not isinstance(item, Range): + raise TypeError( + f"Multirange is a sequence of Range, got {type(item).__name__}" + ) + return item + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._ranges!r})" + + def __str__(self) -> str: + return f"{{{', '.join(map(str, self._ranges))}}}" + + @overload + def __getitem__(self, index: int) -> Range[T]: + ... + + @overload + def __getitem__(self, index: slice) -> "Multirange[T]": + ... + + def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]": + if isinstance(index, int): + return self._ranges[index] + else: + return Multirange(self._ranges[index]) + + def __len__(self) -> int: + return len(self._ranges) + + @overload + def __setitem__(self, index: int, value: Range[T]) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: + ... + + def __setitem__( + self, + index: Union[int, slice], + value: Union[Range[T], Iterable[Range[T]]], + ) -> None: + if isinstance(index, int): + self._check_type(value) + self._ranges[index] = self._check_type(value) + elif not isinstance(value, Iterable): + raise TypeError("can only assign an iterable") + else: + value = map(self._check_type, value) + self._ranges[index] = value + + def __delitem__(self, index: Union[int, slice]) -> None: + del self._ranges[index] + + def insert(self, index: int, value: Range[T]) -> None: + self._ranges.insert(index, self._check_type(value)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return False + return self._ranges == other._ranges + + # Order is arbitrary but consistent + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges < other._ranges + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges > other._ranges + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + +# Subclasses to specify a specific subtype. Usually not needed + + +class Int4Multirange(Multirange[int]): + pass + + +class Int8Multirange(Multirange[int]): + pass + + +class NumericMultirange(Multirange[Decimal]): + pass + + +class DateMultirange(Multirange[date]): + pass + + +class TimestampMultirange(Multirange[datetime]): + pass + + +class TimestamptzMultirange(Multirange[datetime]): + pass + + +class BaseMultirangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self + + item = self._get_item(obj) + if item is None: + return MultirangeDumper(self.cls) + + dumper: BaseMultirangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = MultirangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_multirange_oid(TEXT_OID) + else: + dumper.oid = self._get_multirange_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Multirange[Any]) -> Any: + """ + Return a member representative of the multirange + """ + for r in obj: + if r.lower is not None: + return r.lower + if r.upper is not None: + return r.upper + return None + + def _get_multirange_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class MultirangeDumper(BaseMultirangeDumper): + """ + Dumper for multirange types. + + The dumper can upgrade to one specific for a different range type. + """ + + def dump(self, obj: Multirange[Any]) -> Buffer: + if not obj: + return b"{}" + + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [b"{"] + for r in obj: + out.append(dump_range_text(r, dump)) + out.append(b",") + out[-1] = b"}" + return b"".join(out) + + +class MultirangeBinaryDumper(BaseMultirangeDumper): + + format = Format.BINARY + + def dump(self, obj: Multirange[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [pack_len(len(obj))] + for r in obj: + data = dump_range_binary(r, dump) + out.append(pack_len(len(data))) + out.append(data) + return b"".join(out) + + +class BaseMultirangeLoader(RecursiveLoader, Generic[T]): + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class MultirangeLoader(BaseMultirangeLoader[T]): + def load(self, data: Buffer) -> Multirange[T]: + if not data or data[0] != _START_INT: + raise e.DataError( + "malformed multirange starting with" + f" {bytes(data[:1]).decode('utf8', 'replace')}" + ) + + out = Multirange[T]() + if data == b"{}": + return out + + pos = 1 + data = data[pos:] + try: + while True: + r, pos = load_range_text(data, self._load) + out.append(r) + + sep = data[pos] # can raise IndexError + if sep == _SEP_INT: + data = data[pos + 1 :] + continue + elif sep == _END_INT: + if len(data) == pos + 1: + return out + else: + raise e.DataError( + "malformed multirange: data after closing brace" + ) + else: + raise e.DataError( + f"malformed multirange: found unexpected {chr(sep)}" + ) + + except IndexError: + raise e.DataError("malformed multirange: separator missing") + + return out + + +_SEP_INT = ord(",") +_START_INT = ord("{") +_END_INT = ord("}") + + +class MultirangeBinaryLoader(BaseMultirangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Multirange[T]: + nelems = unpack_len(data, 0)[0] + pos = 4 + out = Multirange[T]() + for i in range(nelems): + length = unpack_len(data, pos)[0] + pos += 4 + out.append(load_range_binary(data[pos : pos + length], self._load)) + pos += length + + if pos != len(data): + raise e.DataError("unexpected trailing data in multirange") + + return out + + +def register_multirange( + info: MultirangeInfo, context: Optional[AdaptContext] = None +) -> None: + """Register the adapters to load and dump a multirange type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested multirange available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[MultirangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (MultirangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[MultirangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (MultirangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeDumper(MultirangeDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeDumper(MultirangeDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeDumper(MultirangeDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeDumper(MultirangeDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeDumper(MultirangeDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeDumper(MultirangeDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Binary dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Text loaders for builtin multirange types + + +class Int4MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeLoader(MultirangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeLoader(MultirangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin multirange types + + +class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Multirange, MultirangeBinaryDumper) + adapters.register_dumper(Multirange, MultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeDumper) + adapters.register_dumper(DateMultirange, DateMultirangeDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper) + adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper) + adapters.register_loader("int4multirange", Int4MultirangeLoader) + adapters.register_loader("int8multirange", Int8MultirangeLoader) + adapters.register_loader("nummultirange", NumericMultirangeLoader) + adapters.register_loader("datemultirange", DateMultirangeLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader) + adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader) + adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader) + adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader) + adapters.register_loader("datemultirange", DateMultirangeBinaryLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader) diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py new file mode 100644 index 0000000..2f2c05b --- /dev/null +++ b/psycopg/psycopg/types/net.py @@ -0,0 +1,206 @@ +""" +Adapters for network types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, Type, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import ipaddress + +Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"] +Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"] +Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"] + +# These objects will be imported lazily +ip_address: Callable[[str], Address] = None # type: ignore[assignment] +ip_interface: Callable[[str], Interface] = None # type: ignore[assignment] +ip_network: Callable[[str], Network] = None # type: ignore[assignment] +IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment] +IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment] +IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment] +IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment] +IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment] +IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment] + +PGSQL_AF_INET = 2 +PGSQL_AF_INET6 = 3 +IPV4_PREFIXLEN = 32 +IPV6_PREFIXLEN = 128 + + +class _LazyIpaddress: + def _ensure_module(self) -> None: + global ip_address, ip_interface, ip_network + global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface + global IPv4Network, IPv6Network + + if ip_address is None: + from ipaddress import ip_address, ip_interface, ip_network + from ipaddress import IPv4Address, IPv6Address + from ipaddress import IPv4Interface, IPv6Interface + from ipaddress import IPv4Network, IPv6Network + + +class InterfaceDumper(Dumper): + + oid = postgres.types["inet"].oid + + def dump(self, obj: Interface) -> bytes: + return str(obj).encode() + + +class NetworkDumper(Dumper): + + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + return str(obj).encode() + + +class _AIBinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["inet"].oid + + +class AddressBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Address) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.max_prefixlen, 0, len(packed))) + return head + packed + + +class InterfaceBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Interface) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.network.prefixlen, 0, len(packed))) + return head + packed + + +class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress): + """Either an address or an interface to inet + + Used when looking up by oid. + """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._ensure_module() + + def dump(self, obj: Union[Address, Interface]) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + if isinstance(obj, (IPv4Interface, IPv6Interface)): + prefixlen = obj.network.prefixlen + else: + prefixlen = obj.max_prefixlen + + head = bytes((family, prefixlen, 0, len(packed))) + return head + packed + + +class NetworkBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + packed = obj.network_address.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.prefixlen, 1, len(packed))) + return head + packed + + +class _LazyIpaddressLoader(Loader, _LazyIpaddress): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._ensure_module() + + +class InetLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + if b"/" in data: + return ip_interface(data.decode()) + else: + return ip_address(data.decode()) + + +class InetBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + if prefix == IPV4_PREFIXLEN: + return IPv4Address(packed) + else: + return IPv4Interface((packed, prefix)) + else: + if prefix == IPV6_PREFIXLEN: + return IPv6Address(packed) + else: + return IPv6Interface((packed, prefix)) + + +class CidrLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + return ip_network(data.decode()) + + +class CidrBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + return IPv4Network((packed, prefix)) + else: + return IPv6Network((packed, prefix)) + + return ip_network(data.decode()) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper) + adapters.register_dumper(None, InetBinaryDumper) + adapters.register_loader("inet", InetLoader) + adapters.register_loader("inet", InetBinaryLoader) + adapters.register_loader("cidr", CidrLoader) + adapters.register_loader("cidr", CidrBinaryLoader) diff --git a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py new file mode 100644 index 0000000..2ab857c --- /dev/null +++ b/psycopg/psycopg/types/none.py @@ -0,0 +1,25 @@ +""" +Adapters for None. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ..abc import AdaptContext, NoneType +from ..adapt import Dumper + + +class NoneDumper(Dumper): + """ + Not a complete dumper as it doesn't implement dump(), but it implements + quote(), so it can be used in sql composition. + """ + + def dump(self, obj: None) -> bytes: + raise NotImplementedError("NULL is passed to Postgres in other ways") + + def quote(self, obj: None) -> bytes: + return b"NULL" + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, NoneDumper) diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py new file mode 100644 index 0000000..1bd9329 --- /dev/null +++ b/psycopg/psycopg/types/numeric.py @@ -0,0 +1,515 @@ +""" +Adapers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from math import log +from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast +from decimal import Decimal, DefaultContext, Context + +from .. import postgres +from .. import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader, PyFormat +from .._struct import pack_int2, pack_uint2, unpack_int2 +from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4 +from .._struct import pack_int8, unpack_int8 +from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8 + +# Exposed here +from .._wrappers import ( + Int2 as Int2, + Int4 as Int4, + Int8 as Int8, + IntNumeric as IntNumeric, + Oid as Oid, + Float4 as Float4, + Float8 as Float8, +) + + +class _IntDumper(Dumper): + def dump(self, obj: Any) -> Buffer: + t = type(obj) + if t is not int: + # Convert to int in order to dump IntEnum correctly + if issubclass(t, int): + obj = int(obj) + else: + raise e.DataError(f"integer expected, got {type(obj).__name__!r}") + + return str(obj).encode() + + def quote(self, obj: Any) -> Buffer: + value = self.dump(obj) + return value if obj >= 0 else b" " + value + + +class _SpecialValuesDumper(Dumper): + + _special: Dict[bytes, bytes] = {} + + def dump(self, obj: Any) -> bytes: + return str(obj).encode() + + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + + if value in self._special: + return self._special[value] + + return value if obj >= 0 else b" " + value + + +class FloatDumper(_SpecialValuesDumper): + + oid = postgres.types["float8"].oid + + _special = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", + } + + +class Float4Dumper(FloatDumper): + oid = postgres.types["float4"].oid + + +class FloatBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["float8"].oid + + def dump(self, obj: float) -> bytes: + return pack_float8(obj) + + +class Float4BinaryDumper(FloatBinaryDumper): + + oid = postgres.types["float4"].oid + + def dump(self, obj: float) -> bytes: + return pack_float4(obj) + + +class DecimalDumper(_SpecialValuesDumper): + + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> bytes: + if obj.is_nan(): + # cover NaN and sNaN + return b"NaN" + else: + return str(obj).encode() + + _special = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", + } + + +class Int2Dumper(_IntDumper): + oid = postgres.types["int2"].oid + + +class Int4Dumper(_IntDumper): + oid = postgres.types["int4"].oid + + +class Int8Dumper(_IntDumper): + oid = postgres.types["int8"].oid + + +class IntNumericDumper(_IntDumper): + oid = postgres.types["numeric"].oid + + +class OidDumper(_IntDumper): + oid = postgres.types["oid"].oid + + +class IntDumper(Dumper): + def dump(self, obj: Any) -> bytes: + raise TypeError( + f"{type(self).__name__} is a dispatcher to other dumpers:" + " dump() is not supposed to be called" + ) + + def get_key(self, obj: int, format: PyFormat) -> type: + return self.upgrade(obj, format).cls + + _int2_dumper = Int2Dumper(Int2) + _int4_dumper = Int4Dumper(Int4) + _int8_dumper = Int8Dumper(Int8) + _int_numeric_dumper = IntNumericDumper(IntNumeric) + + def upgrade(self, obj: int, format: PyFormat) -> Dumper: + if -(2**31) <= obj < 2**31: + if -(2**15) <= obj < 2**15: + return self._int2_dumper + else: + return self._int4_dumper + else: + if -(2**63) <= obj < 2**63: + return self._int8_dumper + else: + return self._int_numeric_dumper + + +class Int2BinaryDumper(Int2Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int2(obj) + + +class Int4BinaryDumper(Int4Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int4(obj) + + +class Int8BinaryDumper(Int8Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int8(obj) + + +# Ratio between number of bits required to store a number and number of pg +# decimal digits required. +BIT_PER_PGDIGIT = log(2) / log(10_000) + + +class IntNumericBinaryDumper(IntNumericDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> Buffer: + return dump_int_to_numeric_binary(obj) + + +class OidBinaryDumper(OidDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_uint4(obj) + + +class IntBinaryDumper(IntDumper): + + format = Format.BINARY + + _int2_dumper = Int2BinaryDumper(Int2) + _int4_dumper = Int4BinaryDumper(Int4) + _int8_dumper = Int8BinaryDumper(Int8) + _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric) + + +class IntLoader(Loader): + def load(self, data: Buffer) -> int: + # it supports bytes directly + return int(data) + + +class Int2BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int2(data)[0] + + +class Int4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int4(data)[0] + + +class Int8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int8(data)[0] + + +class OidBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_uint4(data)[0] + + +class FloatLoader(Loader): + def load(self, data: Buffer) -> float: + # it supports bytes directly + return float(data) + + +class Float4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float4(data)[0] + + +class Float8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float8(data)[0] + + +class NumericLoader(Loader): + def load(self, data: Buffer) -> Decimal: + if isinstance(data, memoryview): + data = bytes(data) + return Decimal(data.decode()) + + +DEC_DIGITS = 4 # decimal digits per Postgres "digit" +NUMERIC_POS = 0x0000 +NUMERIC_NEG = 0x4000 +NUMERIC_NAN = 0xC000 +NUMERIC_PINF = 0xD000 +NUMERIC_NINF = 0xF000 + +_decimal_special = { + NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + + +class _ContextMap(DefaultDict[int, Context]): + """ + Cache for decimal contexts to use when the precision requires it. + + Note: if the default context is used (prec=28) you can get an invalid + operation or a rounding to 0: + + - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000') + - Decimal(1000).shift(25) = Decimal('0') + - Decimal(1000).shift(30) raises InvalidOperation + """ + + def __missing__(self, key: int) -> Context: + val = Context(prec=key) + self[key] = val + return val + + +_contexts = _ContextMap() +for i in range(DefaultContext.prec): + _contexts[i] = DefaultContext + +_unpack_numeric_head = cast( + Callable[[Buffer], Tuple[int, int, int, int]], + struct.Struct("!HhHH").unpack_from, +) +_pack_numeric_head = cast( + Callable[[int, int, int, int], bytes], + struct.Struct("!HhHH").pack, +) + + +class NumericBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Decimal: + ndigits, weight, sign, dscale = _unpack_numeric_head(data) + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + val = 0 + for i in range(8, len(data), 2): + val = val * 10_000 + data[i] * 0x100 + data[i + 1] + + shift = dscale - (ndigits - weight - 1) * DEC_DIGITS + ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale] + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, ctx) + .shift(shift, ctx) + ) + else: + try: + return _decimal_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None + + +NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0) +NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0) +NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0) + + +class DecimalBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> Buffer: + return dump_decimal_to_numeric_binary(obj) + + +class NumericDumper(DecimalDumper): + def dump(self, obj: Union[Decimal, int]) -> bytes: + if isinstance(obj, int): + return str(obj).encode() + else: + return super().dump(obj) + + +class NumericBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Union[Decimal, int]) -> Buffer: + if isinstance(obj, int): + return dump_int_to_numeric_binary(obj) + else: + return dump_decimal_to_numeric_binary(obj) + + +def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]: + sign, digits, exp = obj.as_tuple() + if exp == "n" or exp == "N": # type: ignore[comparison-overlap] + return NUMERIC_NAN_BIN + elif exp == "F": # type: ignore[comparison-overlap] + return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN + + # Weights of py digits into a pg digit according to their positions. + # Starting with an index wi != 0 is equivalent to prepending 0's to + # the digits tuple, but without really changing it. + weights = (1000, 100, 10, 1) + wi = 0 + + ndigits = nzdigits = len(digits) + + # Find the last nonzero digit + while nzdigits > 0 and digits[nzdigits - 1] == 0: + nzdigits -= 1 + + if exp <= 0: + dscale = -exp + else: + dscale = 0 + # align the py digits to the pg digits if there's some py exponent + ndigits += exp % DEC_DIGITS + + if not nzdigits: + return _pack_numeric_head(0, 0, NUMERIC_POS, dscale) + + # Equivalent of 0-padding left to align the py digits to the pg digits + # but without changing the digits tuple. + mod = (ndigits - dscale) % DEC_DIGITS + if mod: + wi = DEC_DIGITS - mod + ndigits += wi + + tmp = nzdigits + wi + out = bytearray( + _pack_numeric_head( + tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits + (ndigits + exp) // DEC_DIGITS - 1, # weight + NUMERIC_NEG if sign else NUMERIC_POS, # sign + dscale, + ) + ) + + pgdigit = 0 + for i in range(nzdigits): + pgdigit += weights[wi] * digits[i] + wi += 1 + if wi >= DEC_DIGITS: + out += pack_uint2(pgdigit) + pgdigit = wi = 0 + + if pgdigit: + out += pack_uint2(pgdigit) + + return out + + +def dump_int_to_numeric_binary(obj: int) -> bytearray: + ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1 + out = bytearray(b"\x00\x00" * (ndigits + 4)) + if obj < 0: + sign = NUMERIC_NEG + obj = -obj + else: + sign = NUMERIC_POS + + out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0) + i = 8 + (ndigits - 1) * 2 + while obj: + rem = obj % 10_000 + obj //= 10_000 + out[i : i + 2] = pack_uint2(rem) + i -= 2 + + return out + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(int, IntDumper) + adapters.register_dumper(int, IntBinaryDumper) + adapters.register_dumper(float, FloatDumper) + adapters.register_dumper(float, FloatBinaryDumper) + adapters.register_dumper(Int2, Int2Dumper) + adapters.register_dumper(Int4, Int4Dumper) + adapters.register_dumper(Int8, Int8Dumper) + adapters.register_dumper(IntNumeric, IntNumericDumper) + adapters.register_dumper(Oid, OidDumper) + + # The binary dumper is currently some 30% slower, so default to text + # (see tests/scripts/testdec.py for a rough benchmark) + # Also, must be after IntNumericDumper + adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper) + adapters.register_dumper("decimal.Decimal", DecimalDumper) + + # Used only by oid, can take both int and Decimal as input + adapters.register_dumper(None, NumericBinaryDumper) + adapters.register_dumper(None, NumericDumper) + + adapters.register_dumper(Float4, Float4Dumper) + adapters.register_dumper(Float8, FloatDumper) + adapters.register_dumper(Int2, Int2BinaryDumper) + adapters.register_dumper(Int4, Int4BinaryDumper) + adapters.register_dumper(Int8, Int8BinaryDumper) + adapters.register_dumper(Oid, OidBinaryDumper) + adapters.register_dumper(Float4, Float4BinaryDumper) + adapters.register_dumper(Float8, FloatBinaryDumper) + adapters.register_loader("int2", IntLoader) + adapters.register_loader("int4", IntLoader) + adapters.register_loader("int8", IntLoader) + adapters.register_loader("oid", IntLoader) + adapters.register_loader("int2", Int2BinaryLoader) + adapters.register_loader("int4", Int4BinaryLoader) + adapters.register_loader("int8", Int8BinaryLoader) + adapters.register_loader("oid", OidBinaryLoader) + adapters.register_loader("float4", FloatLoader) + adapters.register_loader("float8", FloatLoader) + adapters.register_loader("float4", Float4BinaryLoader) + adapters.register_loader("float8", Float8BinaryLoader) + adapters.register_loader("numeric", NumericLoader) + adapters.register_loader("numeric", NumericBinaryLoader) diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py new file mode 100644 index 0000000..c418480 --- /dev/null +++ b/psycopg/psycopg/types/range.py @@ -0,0 +1,700 @@ +""" +Support for range types adaptation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple +from typing import cast +from decimal import Decimal +from datetime import date, datetime + +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import RangeInfo as RangeInfo # exported here + +RANGE_EMPTY = 0x01 # range is empty +RANGE_LB_INC = 0x02 # lower bound is inclusive +RANGE_UB_INC = 0x04 # upper bound is inclusive +RANGE_LB_INF = 0x08 # lower bound is -infinity +RANGE_UB_INF = 0x10 # upper bound is +infinity + +_EMPTY_HEAD = bytes([RANGE_EMPTY]) + +T = TypeVar("T") + + +class Range(Generic[T]): + """Python representation for a PostgreSQL range type. + + :param lower: lower bound for the range. `!None` means unbound + :param upper: upper bound for the range. `!None` means unbound + :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``, + representing whether the lower or upper bounds are included + :param empty: if `!True`, the range is empty + + """ + + __slots__ = ("_lower", "_upper", "_bounds") + + def __init__( + self, + lower: Optional[T] = None, + upper: Optional[T] = None, + bounds: str = "[)", + empty: bool = False, + ): + if not empty: + if bounds not in ("[)", "(]", "()", "[]"): + raise ValueError("bound flags not valid: %r" % bounds) + + self._lower = lower + self._upper = upper + + # Make bounds consistent with infs + if lower is None and bounds[0] == "[": + bounds = "(" + bounds[1] + if upper is None and bounds[1] == "]": + bounds = bounds[0] + ")" + + self._bounds = bounds + else: + self._lower = self._upper = None + self._bounds = "" + + def __repr__(self) -> str: + if self._bounds: + args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}" + else: + args = "empty=True" + + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + if not self._bounds: + return "empty" + + items = [ + self._bounds[0], + str(self._lower), + ", ", + str(self._upper), + self._bounds[1], + ] + return "".join(items) + + @property + def lower(self) -> Optional[T]: + """The lower bound of the range. `!None` if empty or unbound.""" + return self._lower + + @property + def upper(self) -> Optional[T]: + """The upper bound of the range. `!None` if empty or unbound.""" + return self._upper + + @property + def bounds(self) -> str: + """The bounds string (two characters from '[', '(', ']', ')').""" + return self._bounds + + @property + def isempty(self) -> bool: + """`!True` if the range is empty.""" + return not self._bounds + + @property + def lower_inf(self) -> bool: + """`!True` if the range doesn't have a lower bound.""" + if not self._bounds: + return False + return self._lower is None + + @property + def upper_inf(self) -> bool: + """`!True` if the range doesn't have an upper bound.""" + if not self._bounds: + return False + return self._upper is None + + @property + def lower_inc(self) -> bool: + """`!True` if the lower bound is included in the range.""" + if not self._bounds or self._lower is None: + return False + return self._bounds[0] == "[" + + @property + def upper_inc(self) -> bool: + """`!True` if the upper bound is included in the range.""" + if not self._bounds or self._upper is None: + return False + return self._bounds[1] == "]" + + def __contains__(self, x: T) -> bool: + if not self._bounds: + return False + + if self._lower is not None: + if self._bounds[0] == "[": + # It doesn't seem that Python has an ABC for ordered types. + if x < self._lower: # type: ignore[operator] + return False + else: + if x <= self._lower: # type: ignore[operator] + return False + + if self._upper is not None: + if self._bounds[1] == "]": + if x > self._upper: # type: ignore[operator] + return False + else: + if x >= self._upper: # type: ignore[operator] + return False + + return True + + def __bool__(self) -> bool: + return bool(self._bounds) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Range): + return False + return ( + self._lower == other._lower + and self._upper == other._upper + and self._bounds == other._bounds + ) + + def __hash__(self) -> int: + return hash((self._lower, self._upper, self._bounds)) + + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Range): + return NotImplemented + for attr in ("_lower", "_upper", "_bounds"): + self_value = getattr(self, attr) + other_value = getattr(other, attr) + if self_value == other_value: + pass + elif self_value is None: + return True + elif other_value is None: + return False + else: + return cast(bool, self_value < other_value) + return False + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if isinstance(other, Range): + return other < self + else: + return NotImplemented + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + def __getstate__(self) -> Dict[str, Any]: + return { + slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot) + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + for slot, value in state.items(): + setattr(self, slot, value) + + +# Subclasses to specify a specific subtype. Usually not needed: only needed +# in binary copy, where switching to text is not an option. + + +class Int4Range(Range[int]): + pass + + +class Int8Range(Range[int]): + pass + + +class NumericRange(Range[Decimal]): + pass + + +class DateRange(Range[date]): + pass + + +class TimestampRange(Range[datetime]): + pass + + +class TimestamptzRange(Range[datetime]): + pass + + +class BaseRangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self + + item = self._get_item(obj) + if item is None: + return RangeDumper(self.cls) + + dumper: BaseRangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = RangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_range_oid(TEXT_OID) + else: + dumper.oid = self._get_range_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Range[Any]) -> Any: + """ + Return a member representative of the range + """ + rv = obj.lower + return rv if rv is not None else obj.upper + + def _get_range_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class RangeDumper(BaseRangeDumper): + """ + Dumper for range types. + + The dumper can upgrade to one specific for a different range type. + """ + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_text(obj, dump) + + +def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if obj.isempty: + return b"empty" + + parts: List[Buffer] = [b"[" if obj.lower_inc else b"("] + + def dump_item(item: Any) -> Buffer: + ad = dump(item) + if not ad: + return b'""' + elif _re_needs_quotes.search(ad): + return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"' + else: + return ad + + if obj.lower is not None: + parts.append(dump_item(obj.lower)) + + parts.append(b",") + + if obj.upper is not None: + parts.append(dump_item(obj.upper)) + + parts.append(b"]" if obj.upper_inc else b")") + + return b"".join(parts) + + +_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]') +_re_esc = re.compile(rb"([\\\"])") + + +class RangeBinaryDumper(BaseRangeDumper): + + format = Format.BINARY + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_binary(obj, dump) + + +def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if not obj: + return _EMPTY_HEAD + + out = bytearray([0]) # will replace the head later + + head = 0 + if obj.lower_inc: + head |= RANGE_LB_INC + if obj.upper_inc: + head |= RANGE_UB_INC + + if obj.lower is not None: + data = dump(obj.lower) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_LB_INF + + if obj.upper is not None: + data = dump(obj.upper) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_UB_INF + + out[0] = head + return out + + +def fail_dump(obj: Any) -> Buffer: + raise e.InternalError("trying to dump a range element without information") + + +class BaseRangeLoader(RecursiveLoader, Generic[T]): + """Generic loader for a range. + + Subclasses must specify the oid of the subtype and the class to load. + """ + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class RangeLoader(BaseRangeLoader[T]): + def load(self, data: Buffer) -> Range[T]: + return load_range_text(data, self._load)[0] + + +def load_range_text( + data: Buffer, load: Callable[[Buffer], Any] +) -> Tuple[Range[Any], int]: + if data == b"empty": + return Range(empty=True), 5 + + m = _re_range.match(data) + if m is None: + raise e.DataError( + f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'" + ) + + lower = None + item = m.group(3) + if item is None: + item = m.group(2) + if item is not None: + lower = load(_re_undouble.sub(rb"\1", item)) + else: + lower = load(item) + + upper = None + item = m.group(5) + if item is None: + item = m.group(4) + if item is not None: + upper = load(_re_undouble.sub(rb"\1", item)) + else: + upper = load(item) + + bounds = (m.group(1) + m.group(6)).decode() + + return Range(lower, upper, bounds), m.end() + + +_re_range = re.compile( + rb""" + ( \(|\[ ) # lower bound flag + (?: # lower bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^",]+ ) # - or an unquoted string + )? # - or empty (not caught) + , + (?: # upper bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^"\)\]]+ ) # - or an unquoted string + )? # - or empty (not caught) + ( \)|\] ) # upper bound flag + """, + re.VERBOSE, +) + +_re_undouble = re.compile(rb'(["\\])\1') + + +class RangeBinaryLoader(BaseRangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Range[T]: + return load_range_binary(data, self._load) + + +def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]: + head = data[0] + if head & RANGE_EMPTY: + return Range(empty=True) + + lb = "[" if head & RANGE_LB_INC else "(" + ub = "]" if head & RANGE_UB_INC else ")" + + pos = 1 # after the head + if head & RANGE_LB_INF: + min = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + min = load(data[pos : pos + length]) + pos += length + + if head & RANGE_UB_INF: + max = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + max = load(data[pos : pos + length]) + pos += length + + return Range(min, max, lb + ub) + + +def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump a range type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested range available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[RangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (RangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[RangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (RangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeDumper(RangeDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeDumper(RangeDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeDumper(RangeDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeDumper(RangeDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeDumper(RangeDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeDumper(RangeDumper): + oid = postgres.types["tstzrange"].oid + + +# Binary dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tstzrange"].oid + + +# Text loaders for builtin range types + + +class Int4RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeLoader(RangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeLoader(RangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin range types + + +class Int4RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeBinaryLoader(RangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Range, RangeBinaryDumper) + adapters.register_dumper(Range, RangeDumper) + adapters.register_dumper(Int4Range, Int4RangeDumper) + adapters.register_dumper(Int8Range, Int8RangeDumper) + adapters.register_dumper(NumericRange, NumericRangeDumper) + adapters.register_dumper(DateRange, DateRangeDumper) + adapters.register_dumper(TimestampRange, TimestampRangeDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper) + adapters.register_dumper(Int4Range, Int4RangeBinaryDumper) + adapters.register_dumper(Int8Range, Int8RangeBinaryDumper) + adapters.register_dumper(NumericRange, NumericRangeBinaryDumper) + adapters.register_dumper(DateRange, DateRangeBinaryDumper) + adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper) + adapters.register_loader("int4range", Int4RangeLoader) + adapters.register_loader("int8range", Int8RangeLoader) + adapters.register_loader("numrange", NumericRangeLoader) + adapters.register_loader("daterange", DateRangeLoader) + adapters.register_loader("tsrange", TimestampRangeLoader) + adapters.register_loader("tstzrange", TimestampTZRangeLoader) + adapters.register_loader("int4range", Int4RangeBinaryLoader) + adapters.register_loader("int8range", Int8RangeBinaryLoader) + adapters.register_loader("numrange", NumericRangeBinaryLoader) + adapters.register_loader("daterange", DateRangeBinaryLoader) + adapters.register_loader("tsrange", TimestampRangeBinaryLoader) + adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader) diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py new file mode 100644 index 0000000..e99f256 --- /dev/null +++ b/psycopg/psycopg/types/shapely.py @@ -0,0 +1,75 @@ +""" +Adapters for PostGIS geometries +""" + +from typing import Optional + +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Dumper, Loader +from ..pq import Format +from .._typeinfo import TypeInfo + + +try: + from shapely.wkb import loads, dumps + from shapely.geometry.base import BaseGeometry + +except ImportError: + raise ImportError( + "The module psycopg.types.shapely requires the package 'Shapely'" + " to be installed" + ) + + +class GeometryBinaryLoader(Loader): + format = Format.BINARY + + def load(self, data: Buffer) -> "BaseGeometry": + if not isinstance(data, bytes): + data = bytes(data) + return loads(data) + + +class GeometryLoader(Loader): + def load(self, data: Buffer) -> "BaseGeometry": + # it's a hex string in binary + if isinstance(data, memoryview): + data = bytes(data) + return loads(data.decode(), hex=True) + + +class BaseGeometryBinaryDumper(Dumper): + format = Format.BINARY + + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj) # type: ignore + + +class BaseGeometryDumper(Dumper): + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj, hex=True).encode() # type: ignore + + +def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register Shapely dumper and loaders.""" + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'postgis' extension loaded?") + + info.register(context) + adapters = context.adapters if context else postgres.adapters + + class GeometryDumper(BaseGeometryDumper): + oid = info.oid + + class GeometryBinaryDumper(BaseGeometryBinaryDumper): + oid = info.oid + + adapters.register_loader(info.oid, GeometryBinaryLoader) + adapters.register_loader(info.oid, GeometryLoader) + # Default binary dump + adapters.register_dumper(BaseGeometry, GeometryDumper) + adapters.register_dumper(BaseGeometry, GeometryBinaryDumper) diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py new file mode 100644 index 0000000..cd5360d --- /dev/null +++ b/psycopg/psycopg/types/string.py @@ -0,0 +1,239 @@ +""" +Adapters for textual types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Optional, Union, TYPE_CHECKING + +from .. import postgres +from ..pq import Format, Escaping +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from ..errors import DataError +from .._encodings import conn_encoding + +if TYPE_CHECKING: + from ..pq.abc import Escaping as EscapingProto + + +class _BaseStrDumper(Dumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "utf-8" + + +class _StrBinaryDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in binary format. + + Subclasses shall specify the oids of real types (text, varchar, name...). + """ + + format = Format.BINARY + + def dump(self, obj: str) -> bytes: + # the server will raise DataError subclass if the string contains 0x00 + return obj.encode(self._encoding) + + +class _StrDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in text format. + + Subclasses shall specify the oids of real types (text, varchar, name...). + """ + + def dump(self, obj: str) -> bytes: + if "\x00" in obj: + raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes") + else: + return obj.encode(self._encoding) + + +# The next are concrete dumpers, each one specifying the oid they dump to. + + +class StrBinaryDumper(_StrBinaryDumper): + + oid = postgres.types["text"].oid + + +class StrBinaryDumperVarchar(_StrBinaryDumper): + + oid = postgres.types["varchar"].oid + + +class StrBinaryDumperName(_StrBinaryDumper): + + oid = postgres.types["name"].oid + + +class StrDumper(_StrDumper): + """ + Dumper for strings in text format to the text oid. + + Note that this dumper is not used by default because the type is too strict + and PostgreSQL would require an explicit casts to everything that is not a + text field. However it is useful where the unknown oid is ambiguous and the + text oid is required, for instance with variadic functions. + """ + + oid = postgres.types["text"].oid + + +class StrDumperVarchar(_StrDumper): + + oid = postgres.types["varchar"].oid + + +class StrDumperName(_StrDumper): + + oid = postgres.types["name"].oid + + +class StrDumperUnknown(_StrDumper): + """ + Dumper for strings in text format to the unknown oid. + + This dumper is the default dumper for strings and allows to use Python + strings to represent almost every data type. In a few places, however, the + unknown oid is not accepted (for instance in variadic functions such as + 'concat()'). In that case either a cast on the placeholder ('%s::text') or + the StrTextDumper should be used. + """ + + pass + + +class TextLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "" + + def load(self, data: Buffer) -> Union[bytes, str]: + if self._encoding: + if isinstance(data, memoryview): + data = bytes(data) + return data.decode(self._encoding) + else: + # return bytes for SQL_ASCII db + if not isinstance(data, bytes): + data = bytes(data) + return data + + +class TextBinaryLoader(TextLoader): + + format = Format.BINARY + + +class BytesDumper(Dumper): + + oid = postgres.types["bytea"].oid + _qprefix = b"" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._esc = Escaping(self.connection.pgconn if self.connection else None) + + def dump(self, obj: Buffer) -> Buffer: + return self._esc.escape_bytea(obj) + + def quote(self, obj: Buffer) -> bytes: + escaped = self.dump(obj) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self.connection: + if not self._qprefix: + scs = self.connection.pgconn.parameter_status( + b"standard_conforming_strings" + ) + self._qprefix = b"'" if scs == b"on" else b" E'" + + return self._qprefix + escaped + b"'" + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + rv: bytes = b" E'" + escaped + b"'" + if self._esc.escape_bytea(b"\x00") == b"\\000": + rv = rv.replace(b"\\", b"\\\\") + return rv + + +class BytesBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bytea"].oid + + def dump(self, obj: Buffer) -> Buffer: + return obj + + +class ByteaLoader(Loader): + + _escaping: "EscapingProto" + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if not hasattr(self.__class__, "_escaping"): + self.__class__._escaping = Escaping() + + def load(self, data: Buffer) -> bytes: + return self._escaping.unescape_bytea(data) + + +class ByteaBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Buffer: + return data + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + + # NOTE: the order the dumpers are registered is relevant. The last one + # registered becomes the default for each type. Usually, binary is the + # default dumper. For text we use the text dumper as default because it + # plays the role of unknown, and it can be cast automatically to other + # types. However, before that, we register dumper with 'text', 'varchar', + # 'name' oids, which will be used when a text dumper is looked up by oid. + adapters.register_dumper(str, StrBinaryDumperName) + adapters.register_dumper(str, StrBinaryDumperVarchar) + adapters.register_dumper(str, StrBinaryDumper) + adapters.register_dumper(str, StrDumperName) + adapters.register_dumper(str, StrDumperVarchar) + adapters.register_dumper(str, StrDumper) + adapters.register_dumper(str, StrDumperUnknown) + + adapters.register_loader(postgres.INVALID_OID, TextLoader) + adapters.register_loader("bpchar", TextLoader) + adapters.register_loader("name", TextLoader) + adapters.register_loader("text", TextLoader) + adapters.register_loader("varchar", TextLoader) + adapters.register_loader('"char"', TextLoader) + adapters.register_loader("bpchar", TextBinaryLoader) + adapters.register_loader("name", TextBinaryLoader) + adapters.register_loader("text", TextBinaryLoader) + adapters.register_loader("varchar", TextBinaryLoader) + adapters.register_loader('"char"', TextBinaryLoader) + + adapters.register_dumper(bytes, BytesDumper) + adapters.register_dumper(bytearray, BytesDumper) + adapters.register_dumper(memoryview, BytesDumper) + adapters.register_dumper(bytes, BytesBinaryDumper) + adapters.register_dumper(bytearray, BytesBinaryDumper) + adapters.register_dumper(memoryview, BytesBinaryDumper) + + adapters.register_loader("bytea", ByteaLoader) + adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader) + adapters.register_loader("bytea", ByteaBinaryLoader) diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py new file mode 100644 index 0000000..f92354c --- /dev/null +++ b/psycopg/psycopg/types/uuid.py @@ -0,0 +1,65 @@ +""" +Adapters for the UUID type. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, TYPE_CHECKING + +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import uuid + +# Importing the uuid module is slow, so import it only on request. +UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment] + + +class UUIDDumper(Dumper): + + oid = postgres.types["uuid"].oid + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.hex.encode() + + +class UUIDBinaryDumper(UUIDDumper): + + format = Format.BINARY + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.bytes + + +class UUIDLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + global UUID + if UUID is None: + from uuid import UUID + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(data.decode()) + + +class UUIDBinaryLoader(UUIDLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(bytes=data) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("uuid.UUID", UUIDDumper) + adapters.register_dumper("uuid.UUID", UUIDBinaryDumper) + adapters.register_loader("uuid", UUIDLoader) + adapters.register_loader("uuid", UUIDBinaryLoader) diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py new file mode 100644 index 0000000..a98bc35 --- /dev/null +++ b/psycopg/psycopg/version.py @@ -0,0 +1,14 @@ +""" +psycopg distribution version file. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ + +# STOP AND READ! if you change: +__version__ = "3.1.7" +# also change: +# - `docs/news.rst` to declare this as the current version or an unreleased one +# - `psycopg_c/psycopg_c/version.py` to the same version. diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py new file mode 100644 index 0000000..7abfc58 --- /dev/null +++ b/psycopg/psycopg/waiting.py @@ -0,0 +1,331 @@ +""" +Code concerned with waiting in different contexts (blocking, async, etc). + +These functions are designed to consume the generators returned by the +`generators` module function and to return their final value. + +""" + +# Copyright (C) 2020 The Psycopg Team + + +import os +import select +import selectors +from typing import Dict, Optional +from asyncio import get_event_loop, wait_for, Event, TimeoutError +from selectors import DefaultSelector + +from . import errors as e +from .abc import RV, PQGen, PQGenConn, WaitFunc +from ._enums import Wait as Wait, Ready as Ready # re-exported +from ._cmodule import _psycopg + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + + +def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Consume `!gen`, scheduling `fileno` for completion when it is reported to + block. Once ready again send the ready state back to `!gen`. + """ + try: + s = next(gen) + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = None + while not rlist: + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + # note: this line should require a cast, but mypy doesn't complain + ready: Ready = rlist[0][1] + assert s & ready + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Wait for a connection generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + try: + fileno, s = next(gen) + if not timeout: + timeout = None + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + if not rlist: + raise e.ConnectionTimeout("connection timeout expired") + ready: Ready = rlist[0][1] # type: ignore[assignment] + fileno, s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_async(gen: PQGen[RV], fileno: int) -> RV: + """ + Coroutine waiting for a generator to complete. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but exposing an `asyncio` interface. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready |= state # type: ignore[assignment] + ev.set() + + try: + s = next(gen) + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await ev.wait() + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Coroutine waiting for a connection generator to complete. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready = state + ev.set() + + try: + fileno, s = next(gen) + if not timeout: + timeout = None + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await wait_for(ev.wait(), timeout) + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + fileno, s = gen.send(ready) + + except TimeoutError: + raise e.ConnectionTimeout("connection timeout expired") + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +# Specialised implementation of wait functions. + + +def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using select where supported. + """ + try: + s = next(gen) + + empty = () + fnlist = (fileno,) + while True: + rl, wl, xl = select.select( + fnlist if s & WAIT_R else empty, + fnlist if s & WAIT_W else empty, + fnlist, + timeout, + ) + ready = 0 + if rl: + ready = READY_R + if wl: + ready |= READY_W + if not ready: + continue + # assert s & ready + s = gen.send(ready) # type: ignore + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +poll_evmasks: Dict[Wait, int] + +if hasattr(selectors, "EpollSelector"): + poll_evmasks = { + WAIT_R: select.EPOLLONESHOT | select.EPOLLIN, + WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT, + WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT, + } +else: + poll_evmasks = {} + + +def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using epoll where supported. + + Parameters are like for `wait()`. If it is detected that the best selector + strategy is `epoll` then this function will be used instead of `wait`. + + See also: https://linux.die.net/man/2/epoll_ctl + """ + try: + s = next(gen) + + if timeout is None or timeout < 0: + timeout = 0 + else: + timeout = int(timeout * 1000.0) + + with select.epoll() as epoll: + evmask = poll_evmasks[s] + epoll.register(fileno, evmask) + while True: + fileevs = None + while not fileevs: + fileevs = epoll.poll(timeout) + ev = fileevs[0][1] + ready = 0 + if ev & ~select.EPOLLOUT: + ready = READY_R + if ev & ~select.EPOLLIN: + ready |= READY_W + # assert s & ready + s = gen.send(ready) + evmask = poll_evmasks[s] + epoll.modify(fileno, evmask) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +if _psycopg: + wait_c = _psycopg.wait_c + + +# Choose the best wait strategy for the platform. +# +# the selectors objects have a generic interface but come with some overhead, +# so we also offer more finely tuned implementations. + +wait: WaitFunc + +# Allow the user to choose a specific function for testing +if "PSYCOPG_WAIT_FUNC" in os.environ: + fname = os.environ["PSYCOPG_WAIT_FUNC"] + if not fname.startswith("wait_") or fname not in globals(): + raise ImportError( + "PSYCOPG_WAIT_FUNC should be the name of an available wait function;" + f" got {fname!r}" + ) + wait = globals()[fname] + +elif _psycopg: + wait = wait_c + +elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None): + # On Windows, SelectSelector should be the default. + wait = wait_select + +elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None): + # NOTE: select seems more performing than epoll. It is admittedly unlikely + # that a platform has epoll but not select, so maybe we could kill + # wait_epoll altogether(). More testing to do. + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll + +elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None): + # wait_select is faster than wait_selector, probably because of less overhead + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector + +else: + wait = wait_selector diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml new file mode 100644 index 0000000..21e410c --- /dev/null +++ b/psycopg/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37"] +build-backend = "setuptools.build_meta" diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg new file mode 100644 index 0000000..fdcb612 --- /dev/null +++ b/psycopg/setup.cfg @@ -0,0 +1,47 @@ +[metadata] +name = psycopg +description = PostgreSQL database adapter for Python +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +packages = find: +zip_safe = False +install_requires = + backports.zoneinfo >= 0.2.0; python_version < "3.9" + typing-extensions >= 4.1 + tzdata; sys_platform == "win32" + +[options.package_data] +psycopg = py.typed diff --git a/psycopg/setup.py b/psycopg/setup.py new file mode 100644 index 0000000..90d4380 --- /dev/null +++ b/psycopg/setup.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - pure Python package +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +from setuptools import setup + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +# Only for release 3.1.7. Not building binary packages because Scaleway +# has no runner available, but psycopg-binary 3.1.6 should work as well +# as the only change is in rows.py. +version = "3.1.7" +ext_versions = ">= 3.1.6, <= 3.1.7" + +extras_require = { + # Install the C extension module (requires dev tools) + "c": [ + f"psycopg-c {ext_versions}", + ], + # Install the stand-alone C extension module + "binary": [ + f"psycopg-binary {ext_versions}", + ], + # Install the connection pool + "pool": [ + "psycopg-pool", + ], + # Requirements to run the test suite + "test": [ + "mypy >= 0.990", + "pproxy >= 2.7", + "pytest >= 6.2.5", + "pytest-asyncio >= 0.17", + "pytest-cov >= 3.0", + "pytest-randomly >= 3.10", + ], + # Requirements needed for development + "dev": [ + "black >= 22.3.0", + "dnspython >= 2.1", + "flake8 >= 4.0", + "mypy >= 0.990", + "types-setuptools >= 57.4", + "wheel >= 0.37", + ], + # Requirements needed to build the documentation + "docs": [ + "Sphinx >= 5.0", + "furo == 2022.6.21", + "sphinx-autobuild >= 2021.3.14", + "sphinx-autodoc-typehints >= 1.12", + ], +} + +setup( + version=version, + extras_require=extras_require, +) diff --git a/psycopg_c/.flake8 b/psycopg_c/.flake8 new file mode 100644 index 0000000..2ae629c --- /dev/null +++ b/psycopg_c/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 diff --git a/psycopg_c/LICENSE.txt b/psycopg_c/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg_c/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg_c/README-binary.rst b/psycopg_c/README-binary.rst new file mode 100644 index 0000000..9318d57 --- /dev/null +++ b/psycopg_c/README-binary.rst @@ -0,0 +1,29 @@ +Psycopg 3: PostgreSQL database adapter for Python - binary package +================================================================== + +This distribution contains the precompiled optimization package +``psycopg_binary``. + +You shouldn't install this package directly: use instead :: + + pip install "psycopg[binary]" + +to install a version of the optimization package matching the ``psycopg`` +version installed. + +Installing this package requires pip >= 20.3 or newer installed. + +This package is not available for every platform: check out `Binary +installation`__ in the documentation. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #binary-installation + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_c/README.rst b/psycopg_c/README.rst new file mode 100644 index 0000000..de9ba93 --- /dev/null +++ b/psycopg_c/README.rst @@ -0,0 +1,33 @@ +Psycopg 3: PostgreSQL database adapter for Python - optimisation package +======================================================================== + +This distribution contains the optional optimization package ``psycopg_c``. + +You shouldn't install this package directly: use instead :: + + pip install "psycopg[c]" + +to install a version of the optimization package matching the ``psycopg`` +version installed. + +Installing this package requires some prerequisites: check `Local +installation`__ in the documentation. Without a C compiler and some library +headers install *will fail*: this is not a bug. + +If you are unable to meet the prerequisite needed you might want to install +``psycopg[binary]`` instead: look for `Binary installation`__ in the +documentation. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #local-installation +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #binary-installation + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_c/psycopg_c/.gitignore b/psycopg_c/psycopg_c/.gitignore new file mode 100644 index 0000000..36edb64 --- /dev/null +++ b/psycopg_c/psycopg_c/.gitignore @@ -0,0 +1,4 @@ +/*.so +_psycopg.c +pq.c +*.html diff --git a/psycopg_c/psycopg_c/__init__.py b/psycopg_c/psycopg_c/__init__.py new file mode 100644 index 0000000..14db92b --- /dev/null +++ b/psycopg_c/psycopg_c/__init__.py @@ -0,0 +1,14 @@ +""" +psycopg -- PostgreSQL database adapter for Python -- C optimization package +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys + +# This package shouldn't be imported before psycopg itself, or weird things +# will happen +if "psycopg" not in sys.modules: + raise ImportError("the psycopg package should be imported before psycopg_c") + +from .version import __version__ as __version__ # noqa diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi new file mode 100644 index 0000000..bd7c63d --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -0,0 +1,84 @@ +""" +Stub representaton of the public objects exposed by the _psycopg module. + +TODO: this should be generated by mypy's stubgen but it crashes with no +information. Will submit a bug. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Iterable, List, Optional, Sequence, Tuple + +from psycopg import pq +from psycopg import abc +from psycopg.rows import Row, RowMaker +from psycopg.adapt import AdaptersMap, PyFormat +from psycopg.pq.abc import PGconn, PGresult +from psycopg.connection import BaseConnection +from psycopg._compat import Deque + +class Transformer(abc.AdaptContext): + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + def __init__(self, context: Optional[abc.AdaptContext] = None): ... + @classmethod + def from_context(cls, context: Optional[abc.AdaptContext]) -> "Transformer": ... + @property + def connection(self) -> Optional[BaseConnection[Any]]: ... + @property + def encoding(self) -> str: ... + @property + def adapters(self) -> AdaptersMap: ... + @property + def pgresult(self) -> Optional[PGresult]: ... + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None, + ) -> None: ... + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ... + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ... + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[abc.Buffer]]: ... + def as_literal(self, obj: Any) -> bytes: ... + def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... + def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ... + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ... + def load_sequence( + self, record: Sequence[Optional[abc.Buffer]] + ) -> Tuple[Any, ...]: ... + def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... + +# Generators +def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ... +def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... +def send(pgconn: PGconn) -> abc.PQGen[None]: ... +def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... +def fetch(pgconn: PGconn) -> abc.PQGen[Optional[PGresult]]: ... +def pipeline_communicate( + pgconn: PGconn, commands: Deque[abc.PipelineCommand] +) -> abc.PQGen[List[List[PGresult]]]: ... +def wait_c( + gen: abc.PQGen[abc.RV], fileno: int, timeout: Optional[float] = None +) -> abc.RV: ... + +# Copy support +def format_row_text( + row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None +) -> bytearray: ... +def format_row_binary( + row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None +) -> bytearray: ... +def parse_row_text(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ... +def parse_row_binary(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ... + +# Arrays optimization +def array_load_text( + data: abc.Buffer, loader: abc.Loader, delimiter: bytes = b"," +) -> List[Any]: ... +def array_load_binary(data: abc.Buffer, tx: abc.Transformer) -> List[Any]: ... + +# vim: set syntax=python: diff --git a/psycopg_c/psycopg_c/_psycopg.pyx b/psycopg_c/psycopg_c/_psycopg.pyx new file mode 100644 index 0000000..9d2b8ba --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg.pyx @@ -0,0 +1,48 @@ +""" +psycopg_c._psycopg optimization module. + +The module contains optimized C code used in preference to Python code +if a compiler is available. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c cimport pq +from psycopg_c.pq cimport libpq +from psycopg_c._psycopg cimport oids + +import logging + +from psycopg.pq import Format as _pq_Format +from psycopg._enums import PyFormat as _py_Format + +logger = logging.getLogger("psycopg") + +PQ_TEXT = _pq_Format.TEXT +PQ_BINARY = _pq_Format.BINARY + +PG_AUTO = _py_Format.AUTO +PG_TEXT = _py_Format.TEXT +PG_BINARY = _py_Format.BINARY + + +cdef extern from *: + """ +#ifndef ARRAYSIZE +#define ARRAYSIZE(a) ((sizeof(a) / sizeof(*(a)))) +#endif + """ + int ARRAYSIZE(void *array) + + +include "_psycopg/adapt.pyx" +include "_psycopg/copy.pyx" +include "_psycopg/generators.pyx" +include "_psycopg/transform.pyx" +include "_psycopg/waiting.pyx" + +include "types/array.pyx" +include "types/datetime.pyx" +include "types/numeric.pyx" +include "types/bool.pyx" +include "types/string.pyx" diff --git a/psycopg_c/psycopg_c/_psycopg/__init__.pxd b/psycopg_c/psycopg_c/_psycopg/__init__.pxd new file mode 100644 index 0000000..db22deb --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/__init__.pxd @@ -0,0 +1,9 @@ +""" +psycopg_c._psycopg cython module. + +This file is necessary to allow c-importing pxd files from this directory. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c._psycopg cimport oids diff --git a/psycopg_c/psycopg_c/_psycopg/adapt.pyx b/psycopg_c/psycopg_c/_psycopg/adapt.pyx new file mode 100644 index 0000000..a6d8e6a --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/adapt.pyx @@ -0,0 +1,171 @@ +""" +C implementation of the adaptation system. + +This module maps each Python adaptation function to a C adaptation function. +Notice that C adaptation functions have a different signature because they can +avoid making a memory copy, however this makes impossible to expose them to +Python. + +This module exposes facilities to map the builtin adapters in python to +equivalent C implementations. + +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any + +cimport cython + +from libc.string cimport memcpy, memchr +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING + +from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping + +from psycopg import errors as e +from psycopg.pq.misc import error_message + + +@cython.freelist(8) +cdef class CDumper: + + cdef readonly object cls + cdef pq.PGconn _pgconn + + oid = oids.INVALID_OID + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + self.cls = cls + conn = context.connection if context is not None else None + self._pgconn = conn.pgconn if conn is not None else None + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + """Store the Postgres representation *obj* into *rv* at *offset* + + Return the number of bytes written to rv or -1 on Python exception. + + Subclasses must implement this method. The `dump()` implementation + transforms the result of this method to a bytearray so that it can be + returned to Python. + + The function interface allows C code to use this method automatically + to create larger buffers, e.g. for copy, composite objects, etc. + + Implementation note: as you will always need to make sure that rv + has enough space to include what you want to dump, `ensure_size()` + might probably come handy. + """ + raise NotImplementedError() + + def dump(self, obj): + """Return the Postgres representation of *obj* as Python array of bytes""" + cdef rv = PyByteArray_FromStringAndSize("", 0) + cdef Py_ssize_t length = self.cdump(obj, rv, 0) + PyByteArray_Resize(rv, length) + return rv + + def quote(self, obj): + cdef char *ptr + cdef char *ptr_out + cdef Py_ssize_t length + + value = self.dump(obj) + + if self._pgconn is not None: + esc = Escaping(self._pgconn) + # escaping and quoting + return esc.escape_literal(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + rv = PyByteArray_FromStringAndSize("", 0) + + # No quoting, only quote escaping, random bs escaping. See further. + esc = Escaping() + out = esc.escape_string(value) + + _buffer_as_string_and_size(out, &ptr, &length) + + if not memchr(ptr, b'\\', length): + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + PyByteArray_Resize(rv, length + 2) # Must include the quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b"'" + memcpy(ptr_out + 1, ptr, length) + ptr_out[length + 1] = b"'" + return rv + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + if esc.escape_string(b"\\") == b"\\": + rv = bytes(rv).replace(b"\\", b"\\\\") + return rv + + cpdef object get_key(self, object obj, object format): + return self.cls + + cpdef object upgrade(self, object obj, object format): + return self + + @staticmethod + cdef char *ensure_size(bytearray ba, Py_ssize_t offset, Py_ssize_t size) except NULL: + """ + Grow *ba*, if necessary, to contains at least *size* bytes after *offset* + + Return the pointer in the bytearray at *offset*, i.e. the place where + you want to write *size* bytes. + """ + cdef Py_ssize_t curr_size = PyByteArray_GET_SIZE(ba) + cdef Py_ssize_t new_size = offset + size + if curr_size < new_size: + PyByteArray_Resize(ba, new_size) + + return PyByteArray_AS_STRING(ba) + offset + + +@cython.freelist(8) +cdef class CLoader: + cdef public libpq.Oid oid + cdef pq.PGconn _pgconn + + def __cinit__(self, int oid, context: Optional[AdaptContext] = None): + self.oid = oid + conn = context.connection if context is not None else None + self._pgconn = conn.pgconn if conn is not None else None + + cdef object cload(self, const char *data, size_t length): + raise NotImplementedError() + + def load(self, object data) -> Any: + cdef char *ptr + cdef Py_ssize_t length + _buffer_as_string_and_size(data, &ptr, &length) + return self.cload(ptr, length) + + +cdef class _CRecursiveLoader(CLoader): + + cdef Transformer _tx + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + self._tx = Transformer.from_context(context) diff --git a/psycopg_c/psycopg_c/_psycopg/copy.pyx b/psycopg_c/psycopg_c/_psycopg/copy.pyx new file mode 100644 index 0000000..b943095 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/copy.pyx @@ -0,0 +1,340 @@ +""" +C optimised functions for the copy system. + +""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.string cimport memcpy +from libc.stdint cimport uint16_t, uint32_t, int32_t +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE +from cpython.memoryview cimport PyMemoryView_FromObject + +from psycopg_c._psycopg cimport endian +from psycopg_c.pq cimport ViewBuffer + +from psycopg import errors as e + +cdef int32_t _binary_null = -1 + + +def format_row_binary( + row: Sequence[Any], tx: Transformer, out: bytearray = None +) -> bytearray: + """Convert a row of adapted data to the data to send for binary copy""" + cdef Py_ssize_t rowlen = len(row) + cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen) + + cdef Py_ssize_t pos # offset in 'out' where to write + if out is None: + out = PyByteArray_FromStringAndSize("", 0) + pos = 0 + else: + pos = PyByteArray_GET_SIZE(out) + + # let's start from a nice chunk + # (larger than most fixed size; for variable ones, oh well, we'll resize it) + cdef char *target = CDumper.ensure_size( + out, pos, sizeof(berowlen) + 20 * rowlen) + + # Write the number of fields as network-order 16 bits + memcpy(target, <void *>&berowlen, sizeof(berowlen)) + pos += sizeof(berowlen) + + cdef Py_ssize_t size + cdef uint32_t besize + cdef char *buf + cdef int i + cdef PyObject *fmt = <PyObject *>PG_BINARY + cdef PyObject *row_dumper + + if not tx._row_dumpers: + tx._row_dumpers = PyList_New(rowlen) + + dumpers = tx._row_dumpers + + for i in range(rowlen): + item = row[i] + if item is None: + target = CDumper.ensure_size(out, pos, sizeof(_binary_null)) + memcpy(target, <void *>&_binary_null, sizeof(_binary_null)) + pos += sizeof(_binary_null) + continue + + row_dumper = PyList_GET_ITEM(dumpers, i) + if not row_dumper: + row_dumper = tx.get_row_dumper(<PyObject *>item, fmt) + Py_INCREF(<object>row_dumper) + PyList_SET_ITEM(dumpers, i, <object>row_dumper) + + if (<RowDumper>row_dumper).cdumper is not None: + # A cdumper can resize if necessary and copy in place + size = (<RowDumper>row_dumper).cdumper.cdump( + item, out, pos + sizeof(besize)) + # Also add the size of the item, before the item + besize = endian.htobe32(<int32_t>size) + target = PyByteArray_AS_STRING(out) # might have been moved by cdump + memcpy(target + pos, <void *>&besize, sizeof(besize)) + else: + # A Python dumper, gotta call it and extract its juices + b = PyObject_CallFunctionObjArgs( + (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL) + _buffer_as_string_and_size(b, &buf, &size) + target = CDumper.ensure_size(out, pos, size + sizeof(besize)) + besize = endian.htobe32(<int32_t>size) + memcpy(target, <void *>&besize, sizeof(besize)) + memcpy(target + sizeof(besize), buf, size) + + pos += size + sizeof(besize) + + # Resize to the final size + PyByteArray_Resize(out, pos) + return out + + +def format_row_text( + row: Sequence[Any], tx: Transformer, out: bytearray = None +) -> bytearray: + cdef Py_ssize_t pos # offset in 'out' where to write + if out is None: + out = PyByteArray_FromStringAndSize("", 0) + pos = 0 + else: + pos = PyByteArray_GET_SIZE(out) + + cdef Py_ssize_t rowlen = len(row) + + if rowlen == 0: + PyByteArray_Resize(out, pos + 1) + out[pos] = b"\n" + return out + + cdef Py_ssize_t size, tmpsize + cdef char *buf + cdef int i, j + cdef unsigned char *target + cdef int nesc = 0 + cdef int with_tab + cdef PyObject *fmt = <PyObject *>PG_TEXT + cdef PyObject *row_dumper + + for i in range(rowlen): + # Include the tab before the data, so it gets included in the resizes + with_tab = i > 0 + + item = row[i] + if item is None: + if with_tab: + target = <unsigned char *>CDumper.ensure_size(out, pos, 3) + memcpy(target, b"\t\\N", 3) + pos += 3 + else: + target = <unsigned char *>CDumper.ensure_size(out, pos, 2) + memcpy(target, b"\\N", 2) + pos += 2 + continue + + row_dumper = tx.get_row_dumper(<PyObject *>item, fmt) + if (<RowDumper>row_dumper).cdumper is not None: + # A cdumper can resize if necessary and copy in place + size = (<RowDumper>row_dumper).cdumper.cdump( + item, out, pos + with_tab) + target = <unsigned char *>PyByteArray_AS_STRING(out) + pos + else: + # A Python dumper, gotta call it and extract its juices + b = PyObject_CallFunctionObjArgs( + (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL) + _buffer_as_string_and_size(b, &buf, &size) + target = <unsigned char *>CDumper.ensure_size(out, pos, size + with_tab) + memcpy(target + with_tab, buf, size) + + # Prepend a tab to the data just written + if with_tab: + target[0] = b"\t" + target += 1 + pos += 1 + + # Now from pos to pos + size there is a textual representation: it may + # contain chars to escape. Scan to find how many such chars there are. + for j in range(size): + if copy_escape_lut[target[j]]: + nesc += 1 + + # If there is any char to escape, walk backwards pushing the chars + # forward and interspersing backslashes. + if nesc > 0: + tmpsize = size + nesc + target = <unsigned char *>CDumper.ensure_size(out, pos, tmpsize) + for j in range(<int>size - 1, -1, -1): + if copy_escape_lut[target[j]]: + target[j + nesc] = copy_escape_lut[target[j]] + nesc -= 1 + target[j + nesc] = b"\\" + if nesc <= 0: + break + else: + target[j + nesc] = target[j] + pos += tmpsize + else: + pos += size + + # Resize to the final size, add the newline + PyByteArray_Resize(out, pos + 1) + out[pos] = b"\n" + return out + + +def parse_row_binary(data, tx: Transformer) -> Tuple[Any, ...]: + cdef unsigned char *ptr + cdef Py_ssize_t bufsize + _buffer_as_string_and_size(data, <char **>&ptr, &bufsize) + cdef unsigned char *bufend = ptr + bufsize + + cdef uint16_t benfields = (<uint16_t *>ptr)[0] + cdef int nfields = endian.be16toh(benfields) + ptr += sizeof(benfields) + cdef list row = PyList_New(nfields) + + cdef int col + cdef int32_t belength + cdef Py_ssize_t length + + for col in range(nfields): + memcpy(&belength, ptr, sizeof(belength)) + ptr += sizeof(belength) + if belength == _binary_null: + field = None + else: + length = endian.be32toh(belength) + if ptr + length > bufend: + raise e.DataError("bad copy data: length exceeding data") + field = PyMemoryView_FromObject( + ViewBuffer._from_buffer(data, ptr, length)) + ptr += length + + Py_INCREF(field) + PyList_SET_ITEM(row, col, field) + + return tx.load_sequence(row) + + +def parse_row_text(data, tx: Transformer) -> Tuple[Any, ...]: + cdef unsigned char *fstart + cdef Py_ssize_t size + _buffer_as_string_and_size(data, <char **>&fstart, &size) + + # politely assume that the number of fields will be what in the result + cdef int nfields = tx._nfields + cdef list row = PyList_New(nfields) + + cdef unsigned char *fend + cdef unsigned char *rowend = fstart + size + cdef unsigned char *src + cdef unsigned char *tgt + cdef int col + cdef int num_bs + + for col in range(nfields): + fend = fstart + num_bs = 0 + # Scan to the end of the field, remember if you see any backslash + while fend[0] != b'\t' and fend[0] != b'\n' and fend < rowend: + if fend[0] == b'\\': + num_bs += 1 + # skip the next char to avoid counting escaped backslashes twice + fend += 1 + fend += 1 + + # Check if we stopped for the right reason + if fend >= rowend: + raise e.DataError("bad copy data: field delimiter not found") + elif fend[0] == b'\t' and col == nfields - 1: + raise e.DataError("bad copy data: got a tab at the end of the row") + elif fend[0] == b'\n' and col != nfields - 1: + raise e.DataError( + "bad copy format: got a newline before the end of the row") + + # Is this a NULL? + if fend - fstart == 2 and fstart[0] == b'\\' and fstart[1] == b'N': + field = None + + # Is this a field with no backslash? + elif num_bs == 0: + # Nothing to unescape: we don't need a copy + field = PyMemoryView_FromObject( + ViewBuffer._from_buffer(data, fstart, fend - fstart)) + + # This is a field containing backslashes + else: + # We need a copy of the buffer to unescape + field = PyByteArray_FromStringAndSize("", 0) + PyByteArray_Resize(field, fend - fstart - num_bs) + tgt = <unsigned char *>PyByteArray_AS_STRING(field) + src = fstart + while (src < fend): + if src[0] != b'\\': + tgt[0] = src[0] + else: + src += 1 + tgt[0] = copy_unescape_lut[src[0]] + src += 1 + tgt += 1 + + Py_INCREF(field) + PyList_SET_ITEM(row, col, field) + + # Start of the field + fstart = fend + 1 + + # Convert the array of buffers into Python objects + return tx.load_sequence(row) + + +cdef extern from *: + """ +/* handle chars to (un)escape in text copy representation */ +/* '\b', '\t', '\n', '\v', '\f', '\r', '\\' */ + +/* Escaping chars */ +static const char copy_escape_lut[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 98, 116, 110, 118, 102, 114, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 92, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +/* Conversion of escaped to unescaped chars */ +static const char copy_unescape_lut[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 8, 99, 100, 101, 12, 103, 104, 105, 106, 107, 108, 109, 10, 111, +112, 113, 13, 115, 9, 117, 11, 119, 120, 121, 122, 123, 124, 125, 126, 127, +128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, +144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, +160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, +176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, +192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, +208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, +224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, +240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, +}; + """ + const char[256] copy_escape_lut + const char[256] copy_unescape_lut diff --git a/psycopg_c/psycopg_c/_psycopg/endian.pxd b/psycopg_c/psycopg_c/_psycopg/endian.pxd new file mode 100644 index 0000000..44e7305 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/endian.pxd @@ -0,0 +1,155 @@ +""" +Access to endian conversion function +""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.stdint cimport uint16_t, uint32_t, uint64_t + +cdef extern from * nogil: + # from https://gist.github.com/panzi/6856583 + # Improved in: + # https://github.com/linux-sunxi/sunxi-tools/blob/master/include/portable_endian.h + """ +// "License": Public Domain +// I, Mathias Panzenböck, place this file hereby into the public domain. Use it at your own risk for whatever you like. +// In case there are jurisdictions that don't support putting things in the public domain you can also consider it to +// be "dual licensed" under the BSD, MIT and Apache licenses, if you want to. This code is trivial anyway. Consider it +// an example on how to get the endian conversion functions on different platforms. + +#ifndef PORTABLE_ENDIAN_H__ +#define PORTABLE_ENDIAN_H__ + +#if (defined(_WIN16) || defined(_WIN32) || defined(_WIN64)) && !defined(__WINDOWS__) + +# define __WINDOWS__ + +#endif + +#if defined(__linux__) || defined(__CYGWIN__) + +# include <endian.h> + +#elif defined(__APPLE__) + +# include <libkern/OSByteOrder.h> + +# define htobe16(x) OSSwapHostToBigInt16(x) +# define htole16(x) OSSwapHostToLittleInt16(x) +# define be16toh(x) OSSwapBigToHostInt16(x) +# define le16toh(x) OSSwapLittleToHostInt16(x) + +# define htobe32(x) OSSwapHostToBigInt32(x) +# define htole32(x) OSSwapHostToLittleInt32(x) +# define be32toh(x) OSSwapBigToHostInt32(x) +# define le32toh(x) OSSwapLittleToHostInt32(x) + +# define htobe64(x) OSSwapHostToBigInt64(x) +# define htole64(x) OSSwapHostToLittleInt64(x) +# define be64toh(x) OSSwapBigToHostInt64(x) +# define le64toh(x) OSSwapLittleToHostInt64(x) + +# define __BYTE_ORDER BYTE_ORDER +# define __BIG_ENDIAN BIG_ENDIAN +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __PDP_ENDIAN PDP_ENDIAN + +#elif defined(__OpenBSD__) || defined(__NetBSD__) || defined(__FreeBSD__) || defined(__DragonFly__) + +# include <sys/endian.h> + +/* For functions still missing, try to substitute 'historic' OpenBSD names */ +#ifndef be16toh +# define be16toh(x) betoh16(x) +#endif +#ifndef le16toh +# define le16toh(x) letoh16(x) +#endif +#ifndef be32toh +# define be32toh(x) betoh32(x) +#endif +#ifndef le32toh +# define le32toh(x) letoh32(x) +#endif +#ifndef be64toh +# define be64toh(x) betoh64(x) +#endif +#ifndef le64toh +# define le64toh(x) letoh64(x) +#endif + +#elif defined(__WINDOWS__) + +# include <winsock2.h> +# ifndef _MSC_VER +# include <sys/param.h> +# endif + +# if BYTE_ORDER == LITTLE_ENDIAN + +# define htobe16(x) htons(x) +# define htole16(x) (x) +# define be16toh(x) ntohs(x) +# define le16toh(x) (x) + +# define htobe32(x) htonl(x) +# define htole32(x) (x) +# define be32toh(x) ntohl(x) +# define le32toh(x) (x) + +# define htobe64(x) htonll(x) +# define htole64(x) (x) +# define be64toh(x) ntohll(x) +# define le64toh(x) (x) + +# elif BYTE_ORDER == BIG_ENDIAN + + /* that would be xbox 360 */ +# define htobe16(x) (x) +# define htole16(x) __builtin_bswap16(x) +# define be16toh(x) (x) +# define le16toh(x) __builtin_bswap16(x) + +# define htobe32(x) (x) +# define htole32(x) __builtin_bswap32(x) +# define be32toh(x) (x) +# define le32toh(x) __builtin_bswap32(x) + +# define htobe64(x) (x) +# define htole64(x) __builtin_bswap64(x) +# define be64toh(x) (x) +# define le64toh(x) __builtin_bswap64(x) + +# else + +# error byte order not supported + +# endif + +# define __BYTE_ORDER BYTE_ORDER +# define __BIG_ENDIAN BIG_ENDIAN +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __PDP_ENDIAN PDP_ENDIAN + +#else + +# error platform not supported + +#endif + +#endif + """ + cdef uint16_t htobe16(uint16_t host_16bits) + cdef uint16_t htole16(uint16_t host_16bits) + cdef uint16_t be16toh(uint16_t big_endian_16bits) + cdef uint16_t le16toh(uint16_t little_endian_16bits) + + cdef uint32_t htobe32(uint32_t host_32bits) + cdef uint32_t htole32(uint32_t host_32bits) + cdef uint32_t be32toh(uint32_t big_endian_32bits) + cdef uint32_t le32toh(uint32_t little_endian_32bits) + + cdef uint64_t htobe64(uint64_t host_64bits) + cdef uint64_t htole64(uint64_t host_64bits) + cdef uint64_t be64toh(uint64_t big_endian_64bits) + cdef uint64_t le64toh(uint64_t little_endian_64bits) diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx new file mode 100644 index 0000000..9ce9e54 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx @@ -0,0 +1,276 @@ +""" +C implementation of generators for the communication protocols with the libpq +""" + +# Copyright (C) 2020 The Psycopg Team + +from cpython.object cimport PyObject_CallFunctionObjArgs + +from typing import List + +from psycopg import errors as e +from psycopg.pq import abc, error_message +from psycopg.abc import PipelineCommand, PQGen +from psycopg._enums import Wait, Ready +from psycopg._compat import Deque +from psycopg._encodings import conninfo_encoding + +cdef object WAIT_W = Wait.W +cdef object WAIT_R = Wait.R +cdef object WAIT_RW = Wait.RW +cdef object PY_READY_R = Ready.R +cdef object PY_READY_W = Ready.W +cdef object PY_READY_RW = Ready.RW +cdef int READY_R = Ready.R +cdef int READY_W = Ready.W +cdef int READY_RW = Ready.RW + +def connect(conninfo: str) -> PQGenConn[abc.PGconn]: + """ + Generator to create a database connection without blocking. + + """ + cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode()) + cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr + cdef int conn_status = libpq.PQstatus(pgconn_ptr) + cdef int poll_status + + while True: + if conn_status == libpq.CONNECTION_BAD: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection is bad: {error_message(conn, encoding=encoding)}", + pgconn=conn + ) + + with nogil: + poll_status = libpq.PQconnectPoll(pgconn_ptr) + + if poll_status == libpq.PGRES_POLLING_OK: + break + elif poll_status == libpq.PGRES_POLLING_READING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_R) + elif poll_status == libpq.PGRES_POLLING_WRITING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_W) + elif poll_status == libpq.PGRES_POLLING_FAILED: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection failed: {error_message(conn, encoding=encoding)}", + pgconn=conn + ) + else: + raise e.InternalError( + f"unexpected poll status: {poll_status}", pgconn=conn + ) + + conn.nonblocking = 1 + return conn + + +def execute(pq.PGconn pgconn) -> PQGen[List[abc.PGresult]]: + """ + Generator sending a query and returning results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Return the list of results returned by the database (whether success + or error). + """ + yield from send(pgconn) + rv = yield from fetch_many(pgconn) + return rv + + +def send(pq.PGconn pgconn) -> PQGen[None]: + """ + Generator to send a query to the server without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + After this generator has finished you may want to cycle using `fetch()` + to retrieve the results available. + """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int status + cdef int cires + + while True: + if pgconn.flush() == 0: + break + + status = yield WAIT_RW + if status & READY_R: + with nogil: + # This call may read notifies which will be saved in the + # PGconn buffer and passed to Python later. + cires = libpq.PQconsumeInput(pgconn_ptr) + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + + +def fetch_many(pq.PGconn pgconn) -> PQGen[List[PGresult]]: + """ + Generator retrieving results from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return the list of results returned by the database (whether success + or error). + """ + cdef list results = [] + cdef int status + cdef pq.PGresult result + cdef libpq.PGresult *pgres + + while True: + result = yield from fetch(pgconn) + if result is None: + break + results.append(result) + pgres = result._pgresult_ptr + + status = libpq.PQresultStatus(pgres) + if ( + status == libpq.PGRES_COPY_IN + or status == libpq.PGRES_COPY_OUT + or status == libpq.PGRES_COPY_BOTH + ): + # After entering copy mode the libpq will create a phony result + # for every request so let's break the endless loop. + break + + if status == libpq.PGRES_PIPELINE_SYNC: + # PIPELINE_SYNC is not followed by a NULL, but we return it alone + # similarly to other result sets. + break + + return results + + +def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]: + """ + Generator retrieving a single result from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return a result from the database (whether success or error). + """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int cires, ibres + cdef libpq.PGresult *pgres + + with nogil: + ibres = libpq.PQisBusy(pgconn_ptr) + if ibres: + yield WAIT_R + while True: + with nogil: + cires = libpq.PQconsumeInput(pgconn_ptr) + if cires == 1: + ibres = libpq.PQisBusy(pgconn_ptr) + + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + if not ibres: + break + yield WAIT_R + + _consume_notifies(pgconn) + + with nogil: + pgres = libpq.PQgetResult(pgconn_ptr) + if pgres is NULL: + return None + return pq.PGresult._from_ptr(pgres) + + +def pipeline_communicate( + pq.PGconn pgconn, commands: Deque[PipelineCommand] +) -> PQGen[List[List[PGresult]]]: + """Generator to send queries from a connection in pipeline mode while also + receiving results. + + Return a list results, including single PIPELINE_SYNC elements. + """ + cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr + cdef int cires + cdef int status + cdef int ready + cdef libpq.PGresult *pgres + cdef list res = [] + cdef list results = [] + cdef pq.PGresult r + + while True: + ready = yield WAIT_RW + + if ready & READY_R: + with nogil: + cires = libpq.PQconsumeInput(pgconn_ptr) + if 1 != cires: + raise e.OperationalError( + f"consuming input failed: {error_message(pgconn)}") + + _consume_notifies(pgconn) + + res: List[PGresult] = [] + while True: + with nogil: + ibres = libpq.PQisBusy(pgconn_ptr) + if ibres: + break + pgres = libpq.PQgetResult(pgconn_ptr) + + if pgres is NULL: + if not res: + break + results.append(res) + res = [] + else: + status = libpq.PQresultStatus(pgres) + r = pq.PGresult._from_ptr(pgres) + if status == libpq.PGRES_PIPELINE_SYNC: + results.append([r]) + break + else: + res.append(r) + + if ready & READY_W: + pgconn.flush() + if not commands: + break + commands.popleft()() + + return results + + +cdef int _consume_notifies(pq.PGconn pgconn) except -1: + cdef object notify_handler = pgconn.notify_handler + cdef libpq.PGconn *pgconn_ptr + cdef libpq.PGnotify *notify + + if notify_handler is not None: + while True: + pynotify = pgconn.notifies() + if pynotify is None: + break + PyObject_CallFunctionObjArgs( + notify_handler, <PyObject *>pynotify, NULL + ) + else: + pgconn_ptr = pgconn._pgconn_ptr + while True: + notify = libpq.PQnotifies(pgconn_ptr) + if notify is NULL: + break + libpq.PQfreemem(notify) + + return 0 diff --git a/psycopg_c/psycopg_c/_psycopg/oids.pxd b/psycopg_c/psycopg_c/_psycopg/oids.pxd new file mode 100644 index 0000000..2a864c4 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/oids.pxd @@ -0,0 +1,92 @@ +""" +Constants to refer to OIDS in C +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use tools/update_oids.py to update this data. + +cdef enum: + INVALID_OID = 0 + + # autogenerated: start + + # Generated from PostgreSQL 15.0 + + ACLITEM_OID = 1033 + BIT_OID = 1560 + BOOL_OID = 16 + BOX_OID = 603 + BPCHAR_OID = 1042 + BYTEA_OID = 17 + CHAR_OID = 18 + CID_OID = 29 + CIDR_OID = 650 + CIRCLE_OID = 718 + DATE_OID = 1082 + DATEMULTIRANGE_OID = 4535 + DATERANGE_OID = 3912 + FLOAT4_OID = 700 + FLOAT8_OID = 701 + GTSVECTOR_OID = 3642 + INET_OID = 869 + INT2_OID = 21 + INT2VECTOR_OID = 22 + INT4_OID = 23 + INT4MULTIRANGE_OID = 4451 + INT4RANGE_OID = 3904 + INT8_OID = 20 + INT8MULTIRANGE_OID = 4536 + INT8RANGE_OID = 3926 + INTERVAL_OID = 1186 + JSON_OID = 114 + JSONB_OID = 3802 + JSONPATH_OID = 4072 + LINE_OID = 628 + LSEG_OID = 601 + MACADDR_OID = 829 + MACADDR8_OID = 774 + MONEY_OID = 790 + NAME_OID = 19 + NUMERIC_OID = 1700 + NUMMULTIRANGE_OID = 4532 + NUMRANGE_OID = 3906 + OID_OID = 26 + OIDVECTOR_OID = 30 + PATH_OID = 602 + PG_LSN_OID = 3220 + POINT_OID = 600 + POLYGON_OID = 604 + RECORD_OID = 2249 + REFCURSOR_OID = 1790 + REGCLASS_OID = 2205 + REGCOLLATION_OID = 4191 + REGCONFIG_OID = 3734 + REGDICTIONARY_OID = 3769 + REGNAMESPACE_OID = 4089 + REGOPER_OID = 2203 + REGOPERATOR_OID = 2204 + REGPROC_OID = 24 + REGPROCEDURE_OID = 2202 + REGROLE_OID = 4096 + REGTYPE_OID = 2206 + TEXT_OID = 25 + TID_OID = 27 + TIME_OID = 1083 + TIMESTAMP_OID = 1114 + TIMESTAMPTZ_OID = 1184 + TIMETZ_OID = 1266 + TSMULTIRANGE_OID = 4533 + TSQUERY_OID = 3615 + TSRANGE_OID = 3908 + TSTZMULTIRANGE_OID = 4534 + TSTZRANGE_OID = 3910 + TSVECTOR_OID = 3614 + TXID_SNAPSHOT_OID = 2970 + UUID_OID = 2950 + VARBIT_OID = 1562 + VARCHAR_OID = 1043 + XID_OID = 28 + XID8_OID = 5069 + XML_OID = 142 + # autogenerated: end diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx new file mode 100644 index 0000000..fc69725 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -0,0 +1,640 @@ +""" +Helper object to transform values between Python and PostgreSQL + +Cython implementation: can access to lower level C features without creating +too many temporary Python objects and performing less memory copying. + +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.ref cimport Py_INCREF, Py_DECREF +from cpython.set cimport PySet_Add, PySet_Contains +from cpython.dict cimport PyDict_GetItem, PyDict_SetItem +from cpython.list cimport ( + PyList_New, PyList_CheckExact, + PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE) +from cpython.bytes cimport PyBytes_AS_STRING +from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM +from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs + +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from psycopg import errors as e +from psycopg.pq import Format as PqFormat +from psycopg.rows import Row, RowMaker +from psycopg._encodings import pgconn_encoding + +NoneType = type(None) + +# internal structure: you are not supposed to know this. But it's worth some +# 10% of the innermost loop, so I'm willing to ask for forgiveness later... + +ctypedef struct PGresAttValue: + int len + char *value + +ctypedef struct pg_result_int: + # NOTE: it would be advised that we don't know this structure's content + int ntups + int numAttributes + libpq.PGresAttDesc *attDescs + PGresAttValue **tuples + # ...more members, which we ignore + + +@cython.freelist(16) +cdef class RowLoader: + cdef CLoader cloader + cdef object pyloader + cdef object loadfunc + + +@cython.freelist(16) +cdef class RowDumper: + cdef CDumper cdumper + cdef object pydumper + cdef object dumpfunc + cdef object oid + cdef object format + + +cdef class Transformer: + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that attributes + such as the server version or the connection encoding will not change. The + object have its state so adapting several values of the same type can be + optimised. + + """ + + cdef readonly object connection + cdef readonly object adapters + cdef readonly object types + cdef readonly object formats + cdef str _encoding + cdef int _none_oid + + # mapping class -> Dumper instance (auto, text, binary) + cdef dict _auto_dumpers + cdef dict _text_dumpers + cdef dict _binary_dumpers + + # mapping oid -> Loader instance (text, binary) + cdef dict _text_loaders + cdef dict _binary_loaders + + # mapping oid -> Dumper instance (text, binary) + cdef dict _oid_text_dumpers + cdef dict _oid_binary_dumpers + + cdef pq.PGresult _pgresult + cdef int _nfields, _ntuples + cdef list _row_dumpers + cdef list _row_loaders + + cdef dict _oid_types + + def __cinit__(self, context: Optional["AdaptContext"] = None): + if context is not None: + self.adapters = context.adapters + self.connection = context.connection + else: + from psycopg import postgres + self.adapters = postgres.adapters + self.connection = None + + self.types = self.formats = None + self._none_oid = -1 + + @classmethod + def from_context(cls, context: Optional["AdaptContext"]): + """ + Return a Transformer from an AdaptContext. + + If the context is a Transformer instance, just return it. + """ + return _tx_from_context(context) + + @property + def encoding(self) -> str: + if not self._encoding: + conn = self.connection + self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" + return self._encoding + + @property + def pgresult(self) -> Optional[PGresult]: + return self._pgresult + + cpdef set_pgresult( + self, + pq.PGresult result, + object set_loaders = True, + object format = None + ): + self._pgresult = result + + if result is None: + self._nfields = self._ntuples = 0 + if set_loaders: + self._row_loaders = [] + return + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + self._nfields = libpq.PQnfields(res) + self._ntuples = libpq.PQntuples(res) + + if not set_loaders: + return + + if not self._nfields: + self._row_loaders = [] + return + + if format is None: + format = libpq.PQfformat(res, 0) + + cdef list loaders = PyList_New(self._nfields) + cdef PyObject *row_loader + cdef object oid + + cdef int i + for i in range(self._nfields): + oid = libpq.PQftype(res, i) + row_loader = self._c_get_loader(<PyObject *>oid, <PyObject *>format) + Py_INCREF(<object>row_loader) + PyList_SET_ITEM(loaders, i, <object>row_loader) + + self._row_loaders = loaders + + def set_dumper_types(self, types: Sequence[int], format: Format) -> None: + cdef Py_ssize_t ntypes = len(types) + dumpers = PyList_New(ntypes) + cdef int i + for i in range(ntypes): + oid = types[i] + dumper_ptr = self.get_dumper_by_oid( + <PyObject *>oid, <PyObject *>format) + Py_INCREF(<object>dumper_ptr) + PyList_SET_ITEM(dumpers, i, <object>dumper_ptr) + + self._row_dumpers = dumpers + self.types = tuple(types) + self.formats = [format] * ntypes + + def set_loader_types(self, types: Sequence[int], format: Format) -> None: + self._c_loader_types(len(types), types, format) + + cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, object format): + cdef list loaders = PyList_New(ntypes) + + # these are used more as Python object than C + cdef PyObject *oid + cdef PyObject *row_loader + for i in range(ntypes): + oid = PyList_GET_ITEM(types, i) + row_loader = self._c_get_loader(oid, <PyObject *>format) + Py_INCREF(<object>row_loader) + PyList_SET_ITEM(loaders, i, <object>row_loader) + + self._row_loaders = loaders + + cpdef as_literal(self, obj): + cdef PyObject *row_dumper = self.get_row_dumper( + <PyObject *>obj, <PyObject *>PG_TEXT) + + if (<RowDumper>row_dumper).cdumper is not None: + dumper = (<RowDumper>row_dumper).cdumper + else: + dumper = (<RowDumper>row_dumper).pydumper + + rv = dumper.quote(obj) + oid = dumper.oid + # If the result is quoted and the oid not unknown or text, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + if oid and oid != oids.TEXT_OID and rv and rv[-1] == 39: + if self._oid_types is None: + self._oid_types = {} + type_ptr = PyDict_GetItem(<object>self._oid_types, oid) + if type_ptr == NULL: + type_sql = b"" + ti = self.adapters.types.get(oid) + if ti is not None: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + + type_ptr = <PyObject *>type_sql + PyDict_SetItem(<object>self._oid_types, oid, type_sql) + + if <object>type_ptr: + rv = b"%s::%s" % (rv, <object>type_ptr) + + return rv + + def get_dumper(self, obj, format) -> "Dumper": + cdef PyObject *row_dumper = self.get_row_dumper( + <PyObject *>obj, <PyObject *>format) + return (<RowDumper>row_dumper).pydumper + + cdef PyObject *get_row_dumper(self, PyObject *obj, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowDumper for the given obj/fmt. + """ + # Fast path: return a Dumper class already instantiated from the same type + cdef PyObject *cache + cdef PyObject *ptr + cdef PyObject *ptr1 + cdef RowDumper row_dumper + + # Normally, the type of the object dictates how to dump it + key = type(<object>obj) + + # Establish where would the dumper be cached + bfmt = PyUnicode_AsUTF8String(<object>fmt) + cdef char cfmt = PyBytes_AS_STRING(bfmt)[0] + if cfmt == b's': + if self._auto_dumpers is None: + self._auto_dumpers = {} + cache = <PyObject *>self._auto_dumpers + elif cfmt == b'b': + if self._binary_dumpers is None: + self._binary_dumpers = {} + cache = <PyObject *>self._binary_dumpers + elif cfmt == b't': + if self._text_dumpers is None: + self._text_dumpers = {} + cache = <PyObject *>self._text_dumpers + else: + raise ValueError( + f"format should be a psycopg.adapt.Format, not {<object>fmt}") + + # Reuse an existing Dumper class for objects of the same type + ptr = PyDict_GetItem(<object>cache, key) + if ptr == NULL: + dcls = PyObject_CallFunctionObjArgs( + self.adapters.get_dumper, <PyObject *>key, fmt, NULL) + dumper = PyObject_CallFunctionObjArgs( + dcls, <PyObject *>key, <PyObject *>self, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, key, row_dumper) + ptr = <PyObject *>row_dumper + + # Check if the dumper requires an upgrade to handle this specific value + if (<RowDumper>ptr).cdumper is not None: + key1 = (<RowDumper>ptr).cdumper.get_key(<object>obj, <object>fmt) + else: + key1 = PyObject_CallFunctionObjArgs( + (<RowDumper>ptr).pydumper.get_key, obj, fmt, NULL) + if key1 is key: + return ptr + + # If it does, ask the dumper to create its own upgraded version + ptr1 = PyDict_GetItem(<object>cache, key1) + if ptr1 != NULL: + return ptr1 + + if (<RowDumper>ptr).cdumper is not None: + dumper = (<RowDumper>ptr).cdumper.upgrade(<object>obj, <object>fmt) + else: + dumper = PyObject_CallFunctionObjArgs( + (<RowDumper>ptr).pydumper.upgrade, obj, fmt, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, key1, row_dumper) + return <PyObject *>row_dumper + + cdef PyObject *get_dumper_by_oid(self, PyObject *oid, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowDumper for the given oid/fmt. + """ + cdef PyObject *ptr + cdef PyObject *cache + cdef RowDumper row_dumper + + # Establish where would the dumper be cached + cdef int cfmt = <object>fmt + if cfmt == 0: + if self._oid_text_dumpers is None: + self._oid_text_dumpers = {} + cache = <PyObject *>self._oid_text_dumpers + elif cfmt == 1: + if self._oid_binary_dumpers is None: + self._oid_binary_dumpers = {} + cache = <PyObject *>self._oid_binary_dumpers + else: + raise ValueError( + f"format should be a psycopg.pq.Format, not {<object>fmt}") + + # Reuse an existing Dumper class for objects of the same type + ptr = PyDict_GetItem(<object>cache, <object>oid) + if ptr == NULL: + dcls = PyObject_CallFunctionObjArgs( + self.adapters.get_dumper_by_oid, oid, fmt, NULL) + dumper = PyObject_CallFunctionObjArgs( + dcls, <PyObject *>NoneType, <PyObject *>self, NULL) + + row_dumper = _as_row_dumper(dumper) + PyDict_SetItem(<object>cache, <object>oid, row_dumper) + ptr = <PyObject *>row_dumper + + return ptr + + cpdef dump_sequence(self, object params, object formats): + # Verify that they are not none and that PyList_GET_ITEM won't blow up + cdef Py_ssize_t nparams = len(params) + cdef list out = PyList_New(nparams) + + cdef int i + cdef PyObject *dumper_ptr # borrowed pointer to row dumper + cdef object dumped + cdef Py_ssize_t size + + if self._none_oid < 0: + self._none_oid = self.adapters.get_dumper(NoneType, "s").oid + + dumpers = self._row_dumpers + + if dumpers: + for i in range(nparams): + param = params[i] + if param is not None: + dumper_ptr = PyList_GET_ITEM(dumpers, i) + if (<RowDumper>dumper_ptr).cdumper is not None: + dumped = PyByteArray_FromStringAndSize("", 0) + size = (<RowDumper>dumper_ptr).cdumper.cdump( + param, <bytearray>dumped, 0) + PyByteArray_Resize(dumped, size) + else: + dumped = PyObject_CallFunctionObjArgs( + (<RowDumper>dumper_ptr).dumpfunc, + <PyObject *>param, NULL) + else: + dumped = None + + Py_INCREF(dumped) + PyList_SET_ITEM(out, i, dumped) + + return out + + cdef tuple types = PyTuple_New(nparams) + cdef list pqformats = PyList_New(nparams) + + for i in range(nparams): + param = params[i] + if param is not None: + dumper_ptr = self.get_row_dumper( + <PyObject *>param, <PyObject *>formats[i]) + if (<RowDumper>dumper_ptr).cdumper is not None: + dumped = PyByteArray_FromStringAndSize("", 0) + size = (<RowDumper>dumper_ptr).cdumper.cdump( + param, <bytearray>dumped, 0) + PyByteArray_Resize(dumped, size) + else: + dumped = PyObject_CallFunctionObjArgs( + (<RowDumper>dumper_ptr).dumpfunc, + <PyObject *>param, NULL) + oid = (<RowDumper>dumper_ptr).oid + fmt = (<RowDumper>dumper_ptr).format + else: + dumped = None + oid = self._none_oid + fmt = PQ_TEXT + + Py_INCREF(dumped) + PyList_SET_ITEM(out, i, dumped) + + Py_INCREF(oid) + PyTuple_SET_ITEM(types, i, oid) + + Py_INCREF(fmt) + PyList_SET_ITEM(pqformats, i, fmt) + + self.types = types + self.formats = pqformats + return out + + def load_rows(self, int row0, int row1, object make_row) -> List[Row]: + if self._pgresult is None: + raise e.InterfaceError("result not set") + + if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): + raise e.InterfaceError( + f"rows must be included between 0 and {self._ntuples}" + ) + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + # cheeky access to the internal PGresult structure + cdef pg_result_int *ires = <pg_result_int*>res + + cdef int row + cdef int col + cdef PGresAttValue *attval + cdef object record # not 'tuple' as it would check on assignment + + cdef object records = PyList_New(row1 - row0) + for row in range(row0, row1): + record = PyTuple_New(self._nfields) + Py_INCREF(record) + PyList_SET_ITEM(records, row - row0, record) + + cdef PyObject *loader # borrowed RowLoader + cdef PyObject *brecord # borrowed + row_loaders = self._row_loaders # avoid an incref/decref per item + + for col in range(self._nfields): + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + for row in range(row0, row1): + brecord = PyList_GET_ITEM(records, row - row0) + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + pyval = (<RowLoader>loader).cloader.cload( + attval.value, attval.len) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(<object>brecord, col, pyval) + + else: + for row in range(row0, row1): + brecord = PyList_GET_ITEM(records, row - row0) + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + <unsigned char *>attval.value, attval.len)) + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>b, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(<object>brecord, col, pyval) + + if make_row is not tuple: + for i in range(row1 - row0): + brecord = PyList_GET_ITEM(records, i) + record = PyObject_CallFunctionObjArgs( + make_row, <PyObject *>brecord, NULL) + Py_INCREF(record) + PyList_SET_ITEM(records, i, record) + Py_DECREF(<object>brecord) + return records + + def load_row(self, int row, object make_row) -> Optional[Row]: + if self._pgresult is None: + return None + + if not 0 <= row < self._ntuples: + return None + + cdef libpq.PGresult *res = self._pgresult._pgresult_ptr + # cheeky access to the internal PGresult structure + cdef pg_result_int *ires = <pg_result_int*>res + + cdef PyObject *loader # borrowed RowLoader + cdef int col + cdef PGresAttValue *attval + cdef object record # not 'tuple' as it would check on assignment + + record = PyTuple_New(self._nfields) + row_loaders = self._row_loaders # avoid an incref/decref per item + + for col in range(self._nfields): + attval = &(ires.tuples[row][col]) + if attval.len == -1: # NULL_LEN + pyval = None + else: + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + pyval = (<RowLoader>loader).cloader.cload( + attval.value, attval.len) + else: + b = PyMemoryView_FromObject( + ViewBuffer._from_buffer( + self._pgresult, + <unsigned char *>attval.value, attval.len)) + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>b, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(record, col, pyval) + + if make_row is not tuple: + record = PyObject_CallFunctionObjArgs( + make_row, <PyObject *>record, NULL) + return record + + cpdef object load_sequence(self, record: Sequence[Optional[Buffer]]): + cdef Py_ssize_t nfields = len(record) + out = PyTuple_New(nfields) + cdef PyObject *loader # borrowed RowLoader + cdef int col + cdef char *ptr + cdef Py_ssize_t size + + row_loaders = self._row_loaders # avoid an incref/decref per item + if PyList_GET_SIZE(row_loaders) != nfields: + raise e.ProgrammingError( + f"cannot load sequence of {nfields} items:" + f" {len(self._row_loaders)} loaders registered") + + for col in range(nfields): + item = record[col] + if item is None: + Py_INCREF(None) + PyTuple_SET_ITEM(out, col, None) + continue + + loader = PyList_GET_ITEM(row_loaders, col) + if (<RowLoader>loader).cloader is not None: + _buffer_as_string_and_size(item, &ptr, &size) + pyval = (<RowLoader>loader).cloader.cload(ptr, size) + else: + pyval = PyObject_CallFunctionObjArgs( + (<RowLoader>loader).loadfunc, <PyObject *>item, NULL) + + Py_INCREF(pyval) + PyTuple_SET_ITEM(out, col, pyval) + + return out + + def get_loader(self, oid: int, format: pq.Format) -> "Loader": + cdef PyObject *row_loader = self._c_get_loader( + <PyObject *>oid, <PyObject *>format) + return (<RowLoader>row_loader).pyloader + + cdef PyObject *_c_get_loader(self, PyObject *oid, PyObject *fmt) except NULL: + """ + Return a borrowed reference to the RowLoader instance for given oid/fmt + """ + cdef PyObject *ptr + cdef PyObject *cache + + if <object>fmt == PQ_TEXT: + if self._text_loaders is None: + self._text_loaders = {} + cache = <PyObject *>self._text_loaders + elif <object>fmt == PQ_BINARY: + if self._binary_loaders is None: + self._binary_loaders = {} + cache = <PyObject *>self._binary_loaders + else: + raise ValueError( + f"format should be a psycopg.pq.Format, not {format}") + + ptr = PyDict_GetItem(<object>cache, <object>oid) + if ptr != NULL: + return ptr + + loader_cls = self.adapters.get_loader(<object>oid, <object>fmt) + if loader_cls is None: + loader_cls = self.adapters.get_loader(oids.INVALID_OID, <object>fmt) + if loader_cls is None: + raise e.InterfaceError("unknown oid loader not found") + + loader = PyObject_CallFunctionObjArgs( + loader_cls, oid, <PyObject *>self, NULL) + + cdef RowLoader row_loader = RowLoader() + row_loader.pyloader = loader + row_loader.loadfunc = loader.load + if isinstance(loader, CLoader): + row_loader.cloader = <CLoader>loader + + PyDict_SetItem(<object>cache, <object>oid, row_loader) + return <PyObject *>row_loader + + +cdef object _as_row_dumper(object dumper): + cdef RowDumper row_dumper = RowDumper() + + row_dumper.pydumper = dumper + row_dumper.dumpfunc = dumper.dump + row_dumper.oid = dumper.oid + row_dumper.format = dumper.format + + if isinstance(dumper, CDumper): + row_dumper.cdumper = <CDumper>dumper + + return row_dumper + + +cdef Transformer _tx_from_context(object context): + if isinstance(context, Transformer): + return context + else: + return Transformer(context) diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx new file mode 100644 index 0000000..0af6c57 --- /dev/null +++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx @@ -0,0 +1,197 @@ +""" +C implementation of waiting functions +""" + +# Copyright (C) 2022 The Psycopg Team + +from cpython.object cimport PyObject_CallFunctionObjArgs + +cdef extern from *: + """ +#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL) + +#if defined(HAVE_POLL_H) +#include <poll.h> +#elif defined(HAVE_SYS_POLL_H) +#include <sys/poll.h> +#endif + +#else /* no poll available */ + +#ifdef MS_WINDOWS +#include <winsock2.h> +#else +#include <sys/select.h> +#endif + +#endif /* HAVE_POLL */ + +#define SELECT_EV_READ 1 +#define SELECT_EV_WRITE 2 + +#define SEC_TO_MS 1000 +#define SEC_TO_US (1000 * 1000) + +/* Use select to wait for readiness on fileno. + * + * - Return SELECT_EV_* if the file is ready + * - Return 0 on timeout + * - Return -1 (and set an exception) on error. + * + * The wisdom of this function comes from: + * + * - PostgreSQL libpq (see src/interfaces/libpq/fe-misc.c) + * - Python select module (see Modules/selectmodule.c) + */ +static int +wait_c_impl(int fileno, int wait, float timeout) +{ + int select_rv; + int rv = 0; + +#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL) + + struct pollfd input_fd; + int timeout_ms; + + input_fd.fd = fileno; + input_fd.events = POLLERR; + input_fd.revents = 0; + + if (wait & SELECT_EV_READ) { input_fd.events |= POLLIN; } + if (wait & SELECT_EV_WRITE) { input_fd.events |= POLLOUT; } + + if (timeout < 0.0) { + timeout_ms = -1; + } else { + timeout_ms = (int)(timeout * SEC_TO_MS); + } + + Py_BEGIN_ALLOW_THREADS + errno = 0; + select_rv = poll(&input_fd, 1, timeout_ms); + Py_END_ALLOW_THREADS + + if (PyErr_CheckSignals()) { goto finally; } + + if (select_rv < 0) { + goto error; + } + + if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; } + if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; } + +#else + + fd_set ifds; + fd_set ofds; + fd_set efds; + struct timeval tv, *tvptr; + +#ifndef MS_WINDOWS + if (fileno >= 1024) { + PyErr_SetString( + PyExc_ValueError, /* same exception of Python's 'select.select()' */ + "connection file descriptor out of range for 'select()'"); + return -1; + } +#endif + + FD_ZERO(&ifds); + FD_ZERO(&ofds); + FD_ZERO(&efds); + + if (wait & SELECT_EV_READ) { FD_SET(fileno, &ifds); } + if (wait & SELECT_EV_WRITE) { FD_SET(fileno, &ofds); } + FD_SET(fileno, &efds); + + /* Compute appropriate timeout interval */ + if (timeout < 0.0) { + tvptr = NULL; + } + else { + tv.tv_sec = (int)timeout; + tv.tv_usec = (int)(((long)timeout * SEC_TO_US) % SEC_TO_US); + tvptr = &tv; + } + + Py_BEGIN_ALLOW_THREADS + errno = 0; + select_rv = select(fileno + 1, &ifds, &ofds, &efds, tvptr); + Py_END_ALLOW_THREADS + + if (PyErr_CheckSignals()) { goto finally; } + + if (select_rv < 0) { + goto error; + } + + if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; } + if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; } + +#endif /* HAVE_POLL */ + + return rv; + +error: + +#ifdef MS_WINDOWS + if (select_rv == SOCKET_ERROR) { + PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError()); + } +#else + if (select_rv < 0) { + PyErr_SetFromErrno(PyExc_OSError); + } +#endif + else { + PyErr_SetString(PyExc_OSError, "unexpected error from select()"); + } + +finally: + + return -1; + +} + """ + cdef int wait_c_impl(int fileno, int wait, float timeout) except -1 + + +def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV: + """ + Wait for a generator using poll or select. + """ + cdef float ctimeout + cdef int wait, ready + cdef PyObject *pyready + + if timeout is None: + ctimeout = -1.0 + else: + ctimeout = float(timeout) + if ctimeout < 0.0: + ctimeout = -1.0 + + send = gen.send + + try: + wait = next(gen) + + while True: + ready = wait_c_impl(fileno, wait, ctimeout) + if ready == 0: + continue + elif ready == READY_R: + pyready = <PyObject *>PY_READY_R + elif ready == READY_RW: + pyready = <PyObject *>PY_READY_RW + elif ready == READY_W: + pyready = <PyObject *>PY_READY_W + else: + raise AssertionError(f"unexpected ready value: {ready}") + + wait = PyObject_CallFunctionObjArgs(send, pyready, NULL) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd new file mode 100644 index 0000000..57825dd --- /dev/null +++ b/psycopg_c/psycopg_c/pq.pxd @@ -0,0 +1,78 @@ +# Include pid_t but Windows doesn't have it +# Don't use "IF" so that the generated C is portable and can be included +# in the sdist. +cdef extern from * nogil: + """ +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + typedef signed pid_t; +#else + #include <fcntl.h> +#endif + """ + ctypedef signed pid_t + +from psycopg_c.pq cimport libpq + +ctypedef char *(*conn_bytes_f) (const libpq.PGconn *) +ctypedef int(*conn_int_f) (const libpq.PGconn *) + + +cdef class PGconn: + cdef libpq.PGconn* _pgconn_ptr + cdef object __weakref__ + cdef public object notice_handler + cdef public object notify_handler + cdef pid_t _procpid + + @staticmethod + cdef PGconn _from_ptr(libpq.PGconn *ptr) + + cpdef int flush(self) except -1 + cpdef object notifies(self) + + +cdef class PGresult: + cdef libpq.PGresult* _pgresult_ptr + + @staticmethod + cdef PGresult _from_ptr(libpq.PGresult *ptr) + + +cdef class PGcancel: + cdef libpq.PGcancel* pgcancel_ptr + + @staticmethod + cdef PGcancel _from_ptr(libpq.PGcancel *ptr) + + +cdef class Escaping: + cdef PGconn conn + + cpdef escape_literal(self, data) + cpdef escape_identifier(self, data) + cpdef escape_string(self, data) + cpdef escape_bytea(self, data) + cpdef unescape_bytea(self, const unsigned char *data) + + +cdef class PQBuffer: + cdef unsigned char *buf + cdef Py_ssize_t len + + @staticmethod + cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length) + + +cdef class ViewBuffer: + cdef unsigned char *buf + cdef Py_ssize_t len + cdef object obj + + @staticmethod + cdef ViewBuffer _from_buffer( + object obj, unsigned char *buf, Py_ssize_t length) + + +cdef int _buffer_as_string_and_size( + data: "Buffer", char **ptr, Py_ssize_t *length +) except -1 diff --git a/psycopg_c/psycopg_c/pq.pyx b/psycopg_c/psycopg_c/pq.pyx new file mode 100644 index 0000000..d397c17 --- /dev/null +++ b/psycopg_c/psycopg_c/pq.pyx @@ -0,0 +1,38 @@ +""" +libpq Python wrapper using cython bindings. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c.pq cimport libpq + +import logging + +from psycopg import errors as e +from psycopg.pq import Format +from psycopg.pq.misc import error_message + +logger = logging.getLogger("psycopg") + +__impl__ = 'c' +__build_version__ = libpq.PG_VERSION_NUM + + +def version(): + return libpq.PQlibVersion() + + +include "pq/pgconn.pyx" +include "pq/pgresult.pyx" +include "pq/pgcancel.pyx" +include "pq/conninfo.pyx" +include "pq/escaping.pyx" +include "pq/pqbuffer.pyx" + + +# importing the ssl module sets up Python's libcrypto callbacks +import ssl # noqa + +# disable libcrypto setup in libpq, so it won't stomp on the callbacks +# that have already been set up +libpq.PQinitOpenSSL(1, 0) diff --git a/psycopg_c/psycopg_c/pq/__init__.pxd b/psycopg_c/psycopg_c/pq/__init__.pxd new file mode 100644 index 0000000..ce8c60c --- /dev/null +++ b/psycopg_c/psycopg_c/pq/__init__.pxd @@ -0,0 +1,9 @@ +""" +psycopg_c.pq cython module. + +This file is necessary to allow c-importing pxd files from this directory. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg_c.pq cimport libpq diff --git a/psycopg_c/psycopg_c/pq/conninfo.pyx b/psycopg_c/psycopg_c/pq/conninfo.pyx new file mode 100644 index 0000000..3443de1 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/conninfo.pyx @@ -0,0 +1,61 @@ +""" +psycopg_c.pq.Conninfo object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +from psycopg.pq.misc import ConninfoOption + + +class Conninfo: + @classmethod + def get_defaults(cls) -> List[ConninfoOption]: + cdef libpq.PQconninfoOption *opts = libpq.PQconndefaults() + if opts is NULL : + raise MemoryError("couldn't allocate connection defaults") + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + @classmethod + def parse(cls, const char *conninfo) -> List[ConninfoOption]: + cdef char *errmsg = NULL + cdef libpq.PQconninfoOption *opts = libpq.PQconninfoParse(conninfo, &errmsg) + if opts is NULL: + if errmsg is NULL: + raise MemoryError("couldn't allocate on conninfo parse") + else: + exc = e.OperationalError(errmsg.decode("utf8", "replace")) + libpq.PQfreemem(errmsg) + raise exc + + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + def __repr__(self): + return f"<{type(self).__name__} ({self.keyword.decode('ascii')})>" + + +cdef _options_from_array(libpq.PQconninfoOption *opts): + rv = [] + cdef int i = 0 + cdef libpq.PQconninfoOption* opt + while True: + opt = opts + i + if opt.keyword is NULL: + break + rv.append( + ConninfoOption( + keyword=opt.keyword, + envvar=opt.envvar if opt.envvar is not NULL else None, + compiled=opt.compiled if opt.compiled is not NULL else None, + val=opt.val if opt.val is not NULL else None, + label=opt.label if opt.label is not NULL else None, + dispchar=opt.dispchar if opt.dispchar is not NULL else None, + dispsize=opt.dispsize, + ) + ) + i += 1 + + return rv diff --git a/psycopg_c/psycopg_c/pq/escaping.pyx b/psycopg_c/psycopg_c/pq/escaping.pyx new file mode 100644 index 0000000..f0a44d3 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/escaping.pyx @@ -0,0 +1,132 @@ +""" +psycopg_c.pq.Escaping object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +from libc.string cimport strlen +from cpython.mem cimport PyMem_Malloc, PyMem_Free + + +cdef class Escaping: + def __init__(self, PGconn conn = None): + self.conn = conn + + cpdef escape_literal(self, data): + cdef char *out + cdef char *ptr + cdef Py_ssize_t length + + if self.conn is None: + raise e.OperationalError("escape_literal failed: no connection provided") + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + _buffer_as_string_and_size(data, &ptr, &length) + + out = libpq.PQescapeLiteral(self.conn._pgconn_ptr, ptr, length) + if out is NULL: + raise e.OperationalError( + f"escape_literal failed: {error_message(self.conn)}" + ) + + rv = out[:strlen(out)] + libpq.PQfreemem(out) + return rv + + cpdef escape_identifier(self, data): + cdef char *out + cdef char *ptr + cdef Py_ssize_t length + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is None: + raise e.OperationalError("escape_identifier failed: no connection provided") + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + out = libpq.PQescapeIdentifier(self.conn._pgconn_ptr, ptr, length) + if out is NULL: + raise e.OperationalError( + f"escape_identifier failed: {error_message(self.conn)}" + ) + + rv = out[:strlen(out)] + libpq.PQfreemem(out) + return rv + + cpdef escape_string(self, data): + cdef int error + cdef size_t len_out + cdef char *ptr + cdef char *buf_out + cdef Py_ssize_t length + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is not None: + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + buf_out = <char *>PyMem_Malloc(length * 2 + 1) + len_out = libpq.PQescapeStringConn( + self.conn._pgconn_ptr, buf_out, ptr, length, &error + ) + if error: + PyMem_Free(buf_out) + raise e.OperationalError( + f"escape_string failed: {error_message(self.conn)}" + ) + + else: + buf_out = <char *>PyMem_Malloc(length * 2 + 1) + len_out = libpq.PQescapeString(buf_out, ptr, length) + + rv = buf_out[:len_out] + PyMem_Free(buf_out) + return rv + + cpdef escape_bytea(self, data): + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + + if self.conn is not None and self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + _buffer_as_string_and_size(data, &ptr, &length) + + if self.conn is not None: + out = libpq.PQescapeByteaConn( + self.conn._pgconn_ptr, <unsigned char *>ptr, length, &len_out) + else: + out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out - 1] # out includes final 0 + libpq.PQfreemem(out) + return rv + + cpdef unescape_bytea(self, const unsigned char *data): + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. + if self.conn is not None: + if self.conn._pgconn_ptr is NULL: + raise e.OperationalError("the connection is closed") + + cdef size_t len_out + cdef unsigned char *out = libpq.PQunescapeBytea(data, &len_out) + if out is NULL: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out] + libpq.PQfreemem(out) + return rv diff --git a/psycopg_c/psycopg_c/pq/libpq.pxd b/psycopg_c/psycopg_c/pq/libpq.pxd new file mode 100644 index 0000000..5e05e40 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/libpq.pxd @@ -0,0 +1,321 @@ +""" +Libpq header definition for the cython psycopg.pq implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cdef extern from "stdio.h": + + ctypedef struct FILE: + pass + +cdef extern from "pg_config.h": + + int PG_VERSION_NUM + + +cdef extern from "libpq-fe.h": + + # structures and types + + ctypedef unsigned int Oid + + ctypedef struct PGconn: + pass + + ctypedef struct PGresult: + pass + + ctypedef struct PQconninfoOption: + char *keyword + char *envvar + char *compiled + char *val + char *label + char *dispchar + int dispsize + + ctypedef struct PGnotify: + char *relname + int be_pid + char *extra + + ctypedef struct PGcancel: + pass + + ctypedef struct PGresAttDesc: + char *name + Oid tableid + int columnid + int format + Oid typid + int typlen + int atttypmod + + # enums + + ctypedef enum PostgresPollingStatusType: + PGRES_POLLING_FAILED = 0 + PGRES_POLLING_READING + PGRES_POLLING_WRITING + PGRES_POLLING_OK + PGRES_POLLING_ACTIVE + + + ctypedef enum PGPing: + PQPING_OK + PQPING_REJECT + PQPING_NO_RESPONSE + PQPING_NO_ATTEMPT + + ctypedef enum ConnStatusType: + CONNECTION_OK + CONNECTION_BAD + CONNECTION_STARTED + CONNECTION_MADE + CONNECTION_AWAITING_RESPONSE + CONNECTION_AUTH_OK + CONNECTION_SETENV + CONNECTION_SSL_STARTUP + CONNECTION_NEEDED + CONNECTION_CHECK_WRITABLE + CONNECTION_GSS_STARTUP + # CONNECTION_CHECK_TARGET PG 12 + + ctypedef enum PGTransactionStatusType: + PQTRANS_IDLE + PQTRANS_ACTIVE + PQTRANS_INTRANS + PQTRANS_INERROR + PQTRANS_UNKNOWN + + ctypedef enum ExecStatusType: + PGRES_EMPTY_QUERY = 0 + PGRES_COMMAND_OK + PGRES_TUPLES_OK + PGRES_COPY_OUT + PGRES_COPY_IN + PGRES_BAD_RESPONSE + PGRES_NONFATAL_ERROR + PGRES_FATAL_ERROR + PGRES_COPY_BOTH + PGRES_SINGLE_TUPLE + PGRES_PIPELINE_SYNC + PGRES_PIPELINE_ABORT + + # 33.1. Database Connection Control Functions + PGconn *PQconnectdb(const char *conninfo) + PGconn *PQconnectStart(const char *conninfo) + PostgresPollingStatusType PQconnectPoll(PGconn *conn) nogil + PQconninfoOption *PQconndefaults() + PQconninfoOption *PQconninfo(PGconn *conn) + PQconninfoOption *PQconninfoParse(const char *conninfo, char **errmsg) + void PQfinish(PGconn *conn) + void PQreset(PGconn *conn) + int PQresetStart(PGconn *conn) + PostgresPollingStatusType PQresetPoll(PGconn *conn) + PGPing PQping(const char *conninfo) + + # 33.2. Connection Status Functions + char *PQdb(const PGconn *conn) + char *PQuser(const PGconn *conn) + char *PQpass(const PGconn *conn) + char *PQhost(const PGconn *conn) + char *PQhostaddr(const PGconn *conn) + char *PQport(const PGconn *conn) + char *PQtty(const PGconn *conn) + char *PQoptions(const PGconn *conn) + ConnStatusType PQstatus(const PGconn *conn) + PGTransactionStatusType PQtransactionStatus(const PGconn *conn) + const char *PQparameterStatus(const PGconn *conn, const char *paramName) + int PQprotocolVersion(const PGconn *conn) + int PQserverVersion(const PGconn *conn) + char *PQerrorMessage(const PGconn *conn) + int PQsocket(const PGconn *conn) nogil + int PQbackendPID(const PGconn *conn) + int PQconnectionNeedsPassword(const PGconn *conn) + int PQconnectionUsedPassword(const PGconn *conn) + int PQsslInUse(PGconn *conn) # TODO: const in PG 12 docs - verify/report + # TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl + + # 33.3. Command Execution Functions + PGresult *PQexec(PGconn *conn, const char *command) nogil + PGresult *PQexecParams(PGconn *conn, + const char *command, + int nParams, + const Oid *paramTypes, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + PGresult *PQprepare(PGconn *conn, + const char *stmtName, + const char *query, + int nParams, + const Oid *paramTypes) nogil + PGresult *PQexecPrepared(PGconn *conn, + const char *stmtName, + int nParams, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + PGresult *PQdescribePrepared(PGconn *conn, const char *stmtName) nogil + PGresult *PQdescribePortal(PGconn *conn, const char *portalName) nogil + ExecStatusType PQresultStatus(const PGresult *res) nogil + # PQresStatus: not needed, we have pretty enums + char *PQresultErrorMessage(const PGresult *res) nogil + # TODO: PQresultVerboseErrorMessage + char *PQresultErrorField(const PGresult *res, int fieldcode) nogil + void PQclear(PGresult *res) nogil + + # 33.3.2. Retrieving Query Result Information + int PQntuples(const PGresult *res) + int PQnfields(const PGresult *res) + char *PQfname(const PGresult *res, int column_number) + int PQfnumber(const PGresult *res, const char *column_name) + Oid PQftable(const PGresult *res, int column_number) + int PQftablecol(const PGresult *res, int column_number) + int PQfformat(const PGresult *res, int column_number) + Oid PQftype(const PGresult *res, int column_number) + int PQfmod(const PGresult *res, int column_number) + int PQfsize(const PGresult *res, int column_number) + int PQbinaryTuples(const PGresult *res) + char *PQgetvalue(const PGresult *res, int row_number, int column_number) + int PQgetisnull(const PGresult *res, int row_number, int column_number) + int PQgetlength(const PGresult *res, int row_number, int column_number) + int PQnparams(const PGresult *res) + Oid PQparamtype(const PGresult *res, int param_number) + # PQprint: pretty useless + + # 33.3.3. Retrieving Other Result Information + char *PQcmdStatus(PGresult *res) + char *PQcmdTuples(PGresult *res) + Oid PQoidValue(const PGresult *res) + + # 33.3.4. Escaping Strings for Inclusion in SQL Commands + char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length) + char *PQescapeLiteral(PGconn *conn, const char *str, size_t length) + size_t PQescapeStringConn(PGconn *conn, + char *to, const char *from_, size_t length, + int *error) + size_t PQescapeString(char *to, const char *from_, size_t length) + unsigned char *PQescapeByteaConn(PGconn *conn, + const unsigned char *src, + size_t from_length, + size_t *to_length) + unsigned char *PQescapeBytea(const unsigned char *src, + size_t from_length, + size_t *to_length) + unsigned char *PQunescapeBytea(const unsigned char *src, size_t *to_length) + + + # 33.4. Asynchronous Command Processing + int PQsendQuery(PGconn *conn, const char *command) nogil + int PQsendQueryParams(PGconn *conn, + const char *command, + int nParams, + const Oid *paramTypes, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + int PQsendPrepare(PGconn *conn, + const char *stmtName, + const char *query, + int nParams, + const Oid *paramTypes) nogil + int PQsendQueryPrepared(PGconn *conn, + const char *stmtName, + int nParams, + const char * const *paramValues, + const int *paramLengths, + const int *paramFormats, + int resultFormat) nogil + int PQsendDescribePrepared(PGconn *conn, const char *stmtName) nogil + int PQsendDescribePortal(PGconn *conn, const char *portalName) nogil + PGresult *PQgetResult(PGconn *conn) nogil + int PQconsumeInput(PGconn *conn) nogil + int PQisBusy(PGconn *conn) nogil + int PQsetnonblocking(PGconn *conn, int arg) nogil + int PQisnonblocking(const PGconn *conn) + int PQflush(PGconn *conn) nogil + + # 33.5. Retrieving Query Results Row-by-Row + int PQsetSingleRowMode(PGconn *conn) + + # 33.6. Canceling Queries in Progress + PGcancel *PQgetCancel(PGconn *conn) + void PQfreeCancel(PGcancel *cancel) + int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize) + + # 33.8. Asynchronous Notification + PGnotify *PQnotifies(PGconn *conn) nogil + + # 33.9. Functions Associated with the COPY Command + int PQputCopyData(PGconn *conn, const char *buffer, int nbytes) nogil + int PQputCopyEnd(PGconn *conn, const char *errormsg) nogil + int PQgetCopyData(PGconn *conn, char **buffer, int async) nogil + + # 33.10. Control Functions + void PQtrace(PGconn *conn, FILE *stream); + void PQsetTraceFlags(PGconn *conn, int flags); + void PQuntrace(PGconn *conn); + + # 33.11. Miscellaneous Functions + void PQfreemem(void *ptr) nogil + void PQconninfoFree(PQconninfoOption *connOptions) + char *PQencryptPasswordConn( + PGconn *conn, const char *passwd, const char *user, const char *algorithm); + PGresult *PQmakeEmptyPGresult(PGconn *conn, ExecStatusType status) + int PQsetResultAttrs(PGresult *res, int numAttributes, PGresAttDesc *attDescs) + int PQlibVersion() + + # 33.12. Notice Processing + ctypedef void (*PQnoticeReceiver)(void *arg, const PGresult *res) + PQnoticeReceiver PQsetNoticeReceiver( + PGconn *conn, PQnoticeReceiver prog, void *arg) + + # 33.18. SSL Support + void PQinitOpenSSL(int do_ssl, int do_crypto) + + # 34.5 Pipeline Mode + + ctypedef enum PGpipelineStatus: + PQ_PIPELINE_OFF + PQ_PIPELINE_ON + PQ_PIPELINE_ABORTED + + PGpipelineStatus PQpipelineStatus(const PGconn *conn) + int PQenterPipelineMode(PGconn *conn) + int PQexitPipelineMode(PGconn *conn) + int PQpipelineSync(PGconn *conn) + int PQsendFlushRequest(PGconn *conn) + +cdef extern from *: + """ +/* Hack to allow the use of old libpq versions */ +#if PG_VERSION_NUM < 100000 +#define PQencryptPasswordConn(conn, passwd, user, algorithm) NULL +#endif + +#if PG_VERSION_NUM < 120000 +#define PQhostaddr(conn) NULL +#endif + +#if PG_VERSION_NUM < 140000 +#define PGRES_PIPELINE_SYNC 10 +#define PGRES_PIPELINE_ABORTED 11 +typedef enum { + PQ_PIPELINE_OFF, + PQ_PIPELINE_ON, + PQ_PIPELINE_ABORTED +} PGpipelineStatus; +#define PQpipelineStatus(conn) PQ_PIPELINE_OFF +#define PQenterPipelineMode(conn) 0 +#define PQexitPipelineMode(conn) 1 +#define PQpipelineSync(conn) 0 +#define PQsendFlushRequest(conn) 0 +#define PQsetTraceFlags(conn, stream) do {} while (0) +#endif +""" diff --git a/psycopg_c/psycopg_c/pq/pgcancel.pyx b/psycopg_c/psycopg_c/pq/pgcancel.pyx new file mode 100644 index 0000000..b7cbb70 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgcancel.pyx @@ -0,0 +1,32 @@ +""" +psycopg_c.pq.PGcancel object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + + +cdef class PGcancel: + def __cinit__(self): + self.pgcancel_ptr = NULL + + @staticmethod + cdef PGcancel _from_ptr(libpq.PGcancel *ptr): + cdef PGcancel rv = PGcancel.__new__(PGcancel) + rv.pgcancel_ptr = ptr + return rv + + def __dealloc__(self) -> None: + self.free() + + def free(self) -> None: + if self.pgcancel_ptr is not NULL: + libpq.PQfreeCancel(self.pgcancel_ptr) + self.pgcancel_ptr = NULL + + def cancel(self) -> None: + cdef char buf[256] + cdef int res = libpq.PQcancel(self.pgcancel_ptr, buf, sizeof(buf)) + if not res: + raise e.OperationalError( + f"cancel failed: {buf.decode('utf8', 'ignore')}" + ) diff --git a/psycopg_c/psycopg_c/pq/pgconn.pyx b/psycopg_c/psycopg_c/pq/pgconn.pyx new file mode 100644 index 0000000..4a60530 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgconn.pyx @@ -0,0 +1,733 @@ +""" +psycopg_c.pq.PGconn object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cdef extern from * nogil: + """ +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + /* We don't need a real definition for this because Windows is not affected + * by the issue caused by closing the fds after fork. + */ + #define getpid() (0) +#else + #include <unistd.h> +#endif + """ + pid_t getpid() + +from libc.stdio cimport fdopen +from cpython.mem cimport PyMem_Malloc, PyMem_Free +from cpython.bytes cimport PyBytes_AsString +from cpython.memoryview cimport PyMemoryView_FromObject + +import sys + +from psycopg.pq import Format as PqFormat, Trace +from psycopg.pq.misc import PGnotify, connection_summary +from psycopg_c.pq cimport PQBuffer + + +cdef class PGconn: + @staticmethod + cdef PGconn _from_ptr(libpq.PGconn *ptr): + cdef PGconn rv = PGconn.__new__(PGconn) + rv._pgconn_ptr = ptr + + libpq.PQsetNoticeReceiver(ptr, notice_receiver, <void *>rv) + return rv + + def __cinit__(self): + self._pgconn_ptr = NULL + self._procpid = getpid() + + def __dealloc__(self): + # Close the connection only if it was created in this process, + # not if this object is being GC'd after fork. + if self._procpid == getpid(): + self.finish() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self) + return f"<{cls} {info} at 0x{id(self):x}>" + + @classmethod + def connect(cls, const char *conninfo) -> PGconn: + cdef libpq.PGconn* pgconn = libpq.PQconnectdb(conninfo) + if not pgconn: + raise MemoryError("couldn't allocate PGconn") + + return PGconn._from_ptr(pgconn) + + @classmethod + def connect_start(cls, const char *conninfo) -> PGconn: + cdef libpq.PGconn* pgconn = libpq.PQconnectStart(conninfo) + if not pgconn: + raise MemoryError("couldn't allocate PGconn") + + return PGconn._from_ptr(pgconn) + + def connect_poll(self) -> int: + return _call_int(self, <conn_int_f>libpq.PQconnectPoll) + + def finish(self) -> None: + if self._pgconn_ptr is not NULL: + libpq.PQfinish(self._pgconn_ptr) + self._pgconn_ptr = NULL + + @property + def pgconn_ptr(self) -> Optional[int]: + if self._pgconn_ptr: + return <long long><void *>self._pgconn_ptr + else: + return None + + @property + def info(self) -> List["ConninfoOption"]: + _ensure_pgconn(self) + cdef libpq.PQconninfoOption *opts = libpq.PQconninfo(self._pgconn_ptr) + if opts is NULL: + raise MemoryError("couldn't allocate connection info") + rv = _options_from_array(opts) + libpq.PQconninfoFree(opts) + return rv + + def reset(self) -> None: + _ensure_pgconn(self) + libpq.PQreset(self._pgconn_ptr) + + def reset_start(self) -> None: + if not libpq.PQresetStart(self._pgconn_ptr): + raise e.OperationalError("couldn't reset connection") + + def reset_poll(self) -> int: + return _call_int(self, <conn_int_f>libpq.PQresetPoll) + + @classmethod + def ping(self, const char *conninfo) -> int: + return libpq.PQping(conninfo) + + @property + def db(self) -> bytes: + return _call_bytes(self, libpq.PQdb) + + @property + def user(self) -> bytes: + return _call_bytes(self, libpq.PQuser) + + @property + def password(self) -> bytes: + return _call_bytes(self, libpq.PQpass) + + @property + def host(self) -> bytes: + return _call_bytes(self, libpq.PQhost) + + @property + def hostaddr(self) -> bytes: + if libpq.PG_VERSION_NUM < 120000: + raise e.NotSupportedError( + f"PQhostaddr requires libpq from PostgreSQL 12," + f" {libpq.PG_VERSION_NUM} available instead" + ) + + _ensure_pgconn(self) + cdef char *rv = libpq.PQhostaddr(self._pgconn_ptr) + assert rv is not NULL + return rv + + @property + def port(self) -> bytes: + return _call_bytes(self, libpq.PQport) + + @property + def tty(self) -> bytes: + return _call_bytes(self, libpq.PQtty) + + @property + def options(self) -> bytes: + return _call_bytes(self, libpq.PQoptions) + + @property + def status(self) -> int: + return libpq.PQstatus(self._pgconn_ptr) + + @property + def transaction_status(self) -> int: + return libpq.PQtransactionStatus(self._pgconn_ptr) + + def parameter_status(self, const char *name) -> Optional[bytes]: + _ensure_pgconn(self) + cdef const char *rv = libpq.PQparameterStatus(self._pgconn_ptr, name) + if rv is not NULL: + return rv + else: + return None + + @property + def error_message(self) -> bytes: + return libpq.PQerrorMessage(self._pgconn_ptr) + + @property + def protocol_version(self) -> int: + return _call_int(self, libpq.PQprotocolVersion) + + @property + def server_version(self) -> int: + return _call_int(self, libpq.PQserverVersion) + + @property + def socket(self) -> int: + rv = _call_int(self, libpq.PQsocket) + if rv == -1: + raise e.OperationalError("the connection is lost") + return rv + + @property + def backend_pid(self) -> int: + return _call_int(self, libpq.PQbackendPID) + + @property + def needs_password(self) -> bool: + return bool(libpq.PQconnectionNeedsPassword(self._pgconn_ptr)) + + @property + def used_password(self) -> bool: + return bool(libpq.PQconnectionUsedPassword(self._pgconn_ptr)) + + @property + def ssl_in_use(self) -> bool: + return bool(_call_int(self, <conn_int_f>libpq.PQsslInUse)) + + def exec_(self, const char *command) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *pgresult + with nogil: + pgresult = libpq.PQexec(self._pgconn_ptr, command) + if pgresult is NULL: + raise MemoryError("couldn't allocate PGresult") + + return PGresult._from_ptr(pgresult) + + def send_query(self, const char *command) -> None: + _ensure_pgconn(self) + cdef int rv + with nogil: + rv = libpq.PQsendQuery(self._pgconn_ptr, command) + if not rv: + raise e.OperationalError(f"sending query failed: {error_message(self)}") + + def exec_params( + self, + const char *command, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> PGresult: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, param_types, param_formats) + + cdef libpq.PGresult *pgresult + with nogil: + pgresult = libpq.PQexecParams( + self._pgconn_ptr, command, <int>cnparams, ctypes, + <const char *const *>cvalues, clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if pgresult is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(pgresult) + + def send_query_params( + self, + const char *command, + param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> None: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, param_types, param_formats) + + cdef int rv + with nogil: + rv = libpq.PQsendQueryParams( + self._pgconn_ptr, command, <int>cnparams, ctypes, + <const char *const *>cvalues, clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if not rv: + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_prepare( + self, + const char *name, + const char *command, + param_types: Optional[Sequence[int]] = None, + ) -> None: + _ensure_pgconn(self) + + cdef int i + cdef Py_ssize_t nparams = len(param_types) if param_types else 0 + cdef libpq.Oid *atypes = NULL + if nparams: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = param_types[i] + + cdef int rv + with nogil: + rv = libpq.PQsendPrepare( + self._pgconn_ptr, name, command, <int>nparams, atypes + ) + PyMem_Free(atypes) + if not rv: + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_query_prepared( + self, + const char *name, + param_values: Optional[Sequence[Optional[bytes]]], + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> None: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, None, param_formats) + + cdef int rv + with nogil: + rv = libpq.PQsendQueryPrepared( + self._pgconn_ptr, name, <int>cnparams, <const char *const *>cvalues, + clengths, cformats, result_format) + _clear_query_params(ctypes, cvalues, clengths, cformats) + if not rv: + raise e.OperationalError( + f"sending prepared query failed: {error_message(self)}" + ) + + def prepare( + self, + const char *name, + const char *command, + param_types: Optional[Sequence[int]] = None, + ) -> PGresult: + _ensure_pgconn(self) + + cdef int i + cdef Py_ssize_t nparams = len(param_types) if param_types else 0 + cdef libpq.Oid *atypes = NULL + if nparams: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = param_types[i] + + cdef libpq.PGresult *rv + with nogil: + rv = libpq.PQprepare( + self._pgconn_ptr, name, command, <int>nparams, atypes) + PyMem_Free(atypes) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def exec_prepared( + self, + const char *name, + param_values: Optional[Sequence[bytes]], + param_formats: Optional[Sequence[int]] = None, + int result_format = PqFormat.TEXT, + ) -> PGresult: + _ensure_pgconn(self) + + cdef Py_ssize_t cnparams + cdef libpq.Oid *ctypes + cdef char *const *cvalues + cdef int *clengths + cdef int *cformats + cnparams, ctypes, cvalues, clengths, cformats = _query_params_args( + param_values, None, param_formats) + + cdef libpq.PGresult *rv + with nogil: + rv = libpq.PQexecPrepared( + self._pgconn_ptr, name, <int>cnparams, + <const char *const *>cvalues, + clengths, cformats, result_format) + + _clear_query_params(ctypes, cvalues, clengths, cformats) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def describe_prepared(self, const char *name) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *rv = libpq.PQdescribePrepared(self._pgconn_ptr, name) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def send_describe_prepared(self, const char *name) -> None: + _ensure_pgconn(self) + cdef int rv = libpq.PQsendDescribePrepared(self._pgconn_ptr, name) + if not rv: + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def describe_portal(self, const char *name) -> PGresult: + _ensure_pgconn(self) + cdef libpq.PGresult *rv = libpq.PQdescribePortal(self._pgconn_ptr, name) + if rv is NULL: + raise MemoryError("couldn't allocate PGresult") + return PGresult._from_ptr(rv) + + def send_describe_portal(self, const char *name) -> None: + _ensure_pgconn(self) + cdef int rv = libpq.PQsendDescribePortal(self._pgconn_ptr, name) + if not rv: + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def get_result(self) -> Optional["PGresult"]: + cdef libpq.PGresult *pgresult = libpq.PQgetResult(self._pgconn_ptr) + if pgresult is NULL: + return None + return PGresult._from_ptr(pgresult) + + def consume_input(self) -> None: + if 1 != libpq.PQconsumeInput(self._pgconn_ptr): + raise e.OperationalError(f"consuming input failed: {error_message(self)}") + + def is_busy(self) -> int: + cdef int rv + with nogil: + rv = libpq.PQisBusy(self._pgconn_ptr) + return rv + + @property + def nonblocking(self) -> int: + return libpq.PQisnonblocking(self._pgconn_ptr) + + @nonblocking.setter + def nonblocking(self, int arg) -> None: + if 0 > libpq.PQsetnonblocking(self._pgconn_ptr, arg): + raise e.OperationalError(f"setting nonblocking failed: {error_message(self)}") + + cpdef int flush(self) except -1: + if self._pgconn_ptr == NULL: + raise e.OperationalError(f"flushing failed: the connection is closed") + cdef int rv = libpq.PQflush(self._pgconn_ptr) + if rv < 0: + raise e.OperationalError(f"flushing failed: {error_message(self)}") + return rv + + def set_single_row_mode(self) -> None: + if not libpq.PQsetSingleRowMode(self._pgconn_ptr): + raise e.OperationalError("setting single row mode failed") + + def get_cancel(self) -> PGcancel: + cdef libpq.PGcancel *ptr = libpq.PQgetCancel(self._pgconn_ptr) + if not ptr: + raise e.OperationalError("couldn't create cancel object") + return PGcancel._from_ptr(ptr) + + cpdef object notifies(self): + cdef libpq.PGnotify *ptr + with nogil: + ptr = libpq.PQnotifies(self._pgconn_ptr) + if ptr: + ret = PGnotify(ptr.relname, ptr.be_pid, ptr.extra) + libpq.PQfreemem(ptr) + return ret + else: + return None + + def put_copy_data(self, buffer) -> int: + cdef int rv + cdef char *cbuffer + cdef Py_ssize_t length + + _buffer_as_string_and_size(buffer, &cbuffer, &length) + rv = libpq.PQputCopyData(self._pgconn_ptr, cbuffer, <int>length) + if rv < 0: + raise e.OperationalError(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + cdef int rv + cdef const char *cerr = NULL + if error is not None: + cerr = PyBytes_AsString(error) + rv = libpq.PQputCopyEnd(self._pgconn_ptr, cerr) + if rv < 0: + raise e.OperationalError(f"sending copy end failed: {error_message(self)}") + return rv + + def get_copy_data(self, int async_) -> Tuple[int, memoryview]: + cdef char *buffer_ptr = NULL + cdef int nbytes + nbytes = libpq.PQgetCopyData(self._pgconn_ptr, &buffer_ptr, async_) + if nbytes == -2: + raise e.OperationalError(f"receiving copy data failed: {error_message(self)}") + if buffer_ptr is not NULL: + data = PyMemoryView_FromObject( + PQBuffer._from_buffer(<unsigned char *>buffer_ptr, nbytes)) + return nbytes, data + else: + return nbytes, b"" # won't parse it, doesn't really be memoryview + + def trace(self, fileno: int) -> None: + if sys.platform != "linux": + raise e.NotSupportedError("currently only supported on Linux") + stream = fdopen(fileno, b"w") + libpq.PQtrace(self._pgconn_ptr, stream) + + def set_trace_flags(self, flags: Trace) -> None: + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQsetTraceFlags requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + libpq.PQsetTraceFlags(self._pgconn_ptr, flags) + + def untrace(self) -> None: + libpq.PQuntrace(self._pgconn_ptr) + + def encrypt_password( + self, const char *passwd, const char *user, algorithm = None + ) -> bytes: + if libpq.PG_VERSION_NUM < 100000: + raise e.NotSupportedError( + f"PQencryptPasswordConn requires libpq from PostgreSQL 10," + f" {libpq.PG_VERSION_NUM} available instead" + ) + + cdef char *out + cdef const char *calgo = NULL + if algorithm: + calgo = algorithm + out = libpq.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, calgo) + if not out: + raise e.OperationalError( + f"password encryption failed: {error_message(self)}" + ) + + rv = bytes(out) + libpq.PQfreemem(out) + return rv + + def make_empty_result(self, int exec_status) -> PGresult: + cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult( + self._pgconn_ptr, <libpq.ExecStatusType>exec_status) + if not rv: + raise MemoryError("couldn't allocate empty PGresult") + return PGresult._from_ptr(rv) + + @property + def pipeline_status(self) -> int: + """The current pipeline mode status. + + For libpq < 14.0, always return 0 (PQ_PIPELINE_OFF). + """ + if libpq.PG_VERSION_NUM < 140000: + return libpq.PQ_PIPELINE_OFF + cdef int status = libpq.PQpipelineStatus(self._pgconn_ptr) + return status + + def enter_pipeline_mode(self) -> None: + """Enter pipeline mode. + + :raises ~e.OperationalError: in case of failure to enter the pipeline + mode. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQenterPipelineMode requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + if libpq.PQenterPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError("failed to enter pipeline mode") + + def exit_pipeline_mode(self) -> None: + """Exit pipeline mode. + + :raises ~e.OperationalError: in case of failure to exit the pipeline + mode. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQexitPipelineMode requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + if libpq.PQexitPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError(error_message(self)) + + def pipeline_sync(self) -> None: + """Mark a synchronization point in a pipeline. + + :raises ~e.OperationalError: if the connection is not in pipeline mode + or if sync failed. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQpipelineSync requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + rv = libpq.PQpipelineSync(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError("connection not in pipeline mode") + if rv != 1: + raise e.OperationalError("failed to sync pipeline") + + def send_flush_request(self) -> None: + """Sends a request for the server to flush its output buffer. + + :raises ~e.OperationalError: if the flush request failed. + """ + if libpq.PG_VERSION_NUM < 140000: + raise e.NotSupportedError( + f"PQsendFlushRequest requires libpq from PostgreSQL 14," + f" {libpq.PG_VERSION_NUM} available instead" + ) + cdef int rv = libpq.PQsendFlushRequest(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError(f"flush request failed: {error_message(self)}") + + +cdef int _ensure_pgconn(PGconn pgconn) except 0: + if pgconn._pgconn_ptr is not NULL: + return 1 + + raise e.OperationalError("the connection is closed") + + +cdef char *_call_bytes(PGconn pgconn, conn_bytes_f func) except NULL: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not _ensure_pgconn(pgconn): + return NULL + cdef char *rv = func(pgconn._pgconn_ptr) + assert rv is not NULL + return rv + + +cdef int _call_int(PGconn pgconn, conn_int_f func) except -2: + """ + Call one of the pgconn libpq functions returning an int. + """ + if not _ensure_pgconn(pgconn): + return -2 + return func(pgconn._pgconn_ptr) + + +cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil: + cdef PGconn pgconn = <object>arg + if pgconn.notice_handler is None: + return + + cdef PGresult res = PGresult._from_ptr(<libpq.PGresult *>res_ptr) + try: + pgconn.notice_handler(res) + except Exception as e: + logger.exception("error in notice receiver: %s", e) + finally: + res._pgresult_ptr = NULL # avoid destroying the pgresult_ptr + + +cdef (Py_ssize_t, libpq.Oid *, char * const*, int *, int *) _query_params_args( + list param_values: Optional[Sequence[Optional[bytes]]], + param_types: Optional[Sequence[int]], + list param_formats: Optional[Sequence[int]], +) except *: + cdef int i + + # the PostgresQuery converts the param_types to tuple, so this operation + # is most often no-op + cdef tuple tparam_types + if param_types is not None and not isinstance(param_types, tuple): + tparam_types = tuple(param_types) + else: + tparam_types = param_types + + cdef Py_ssize_t nparams = len(param_values) if param_values else 0 + if tparam_types is not None and len(tparam_types) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(tparam_types)) + ) + if param_formats is not None and len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_formats" + % (nparams, len(param_formats)) + ) + + cdef char **aparams = NULL + cdef int *alenghts = NULL + cdef char *ptr + cdef Py_ssize_t length + + if nparams: + aparams = <char **>PyMem_Malloc(nparams * sizeof(char *)) + alenghts = <int *>PyMem_Malloc(nparams * sizeof(int)) + for i in range(nparams): + obj = param_values[i] + if obj is None: + aparams[i] = NULL + alenghts[i] = 0 + else: + # TODO: it is a leak if this fails (but it should only fail + # on internal error, e.g. if obj is not a buffer) + _buffer_as_string_and_size(obj, &ptr, &length) + aparams[i] = ptr + alenghts[i] = <int>length + + cdef libpq.Oid *atypes = NULL + if tparam_types: + atypes = <libpq.Oid *>PyMem_Malloc(nparams * sizeof(libpq.Oid)) + for i in range(nparams): + atypes[i] = tparam_types[i] + + cdef int *aformats = NULL + if param_formats is not None: + aformats = <int *>PyMem_Malloc(nparams * sizeof(int *)) + for i in range(nparams): + aformats[i] = param_formats[i] + + return (nparams, atypes, aparams, alenghts, aformats) + + +cdef void _clear_query_params( + libpq.Oid *ctypes, char *const *cvalues, int *clenghst, int *cformats +): + PyMem_Free(ctypes) + PyMem_Free(<char **>cvalues) + PyMem_Free(clenghst) + PyMem_Free(cformats) diff --git a/psycopg_c/psycopg_c/pq/pgresult.pyx b/psycopg_c/psycopg_c/pq/pgresult.pyx new file mode 100644 index 0000000..6df42e8 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pgresult.pyx @@ -0,0 +1,157 @@ +""" +psycopg_c.pq.PGresult object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.mem cimport PyMem_Malloc, PyMem_Free + +from psycopg.pq.misc import PGresAttDesc +from psycopg.pq._enums import ExecStatus + + +@cython.freelist(8) +cdef class PGresult: + def __cinit__(self): + self._pgresult_ptr = NULL + + @staticmethod + cdef PGresult _from_ptr(libpq.PGresult *ptr): + cdef PGresult rv = PGresult.__new__(PGresult) + rv._pgresult_ptr = ptr + return rv + + def __dealloc__(self) -> None: + self.clear() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + status = ExecStatus(self.status) + return f"<{cls} [{status.name}] at 0x{id(self):x}>" + + def clear(self) -> None: + if self._pgresult_ptr is not NULL: + libpq.PQclear(self._pgresult_ptr) + self._pgresult_ptr = NULL + + @property + def pgresult_ptr(self) -> Optional[int]: + if self._pgresult_ptr: + return <long long><void *>self._pgresult_ptr + else: + return None + + @property + def status(self) -> int: + return libpq.PQresultStatus(self._pgresult_ptr) + + @property + def error_message(self) -> bytes: + return libpq.PQresultErrorMessage(self._pgresult_ptr) + + def error_field(self, int fieldcode) -> Optional[bytes]: + cdef char * rv = libpq.PQresultErrorField(self._pgresult_ptr, fieldcode) + if rv is not NULL: + return rv + else: + return None + + @property + def ntuples(self) -> int: + return libpq.PQntuples(self._pgresult_ptr) + + @property + def nfields(self) -> int: + return libpq.PQnfields(self._pgresult_ptr) + + def fname(self, int column_number) -> Optional[bytes]: + cdef char *rv = libpq.PQfname(self._pgresult_ptr, column_number) + if rv is not NULL: + return rv + else: + return None + + def ftable(self, int column_number) -> int: + return libpq.PQftable(self._pgresult_ptr, column_number) + + def ftablecol(self, int column_number) -> int: + return libpq.PQftablecol(self._pgresult_ptr, column_number) + + def fformat(self, int column_number) -> int: + return libpq.PQfformat(self._pgresult_ptr, column_number) + + def ftype(self, int column_number) -> int: + return libpq.PQftype(self._pgresult_ptr, column_number) + + def fmod(self, int column_number) -> int: + return libpq.PQfmod(self._pgresult_ptr, column_number) + + def fsize(self, int column_number) -> int: + return libpq.PQfsize(self._pgresult_ptr, column_number) + + @property + def binary_tuples(self) -> int: + return libpq.PQbinaryTuples(self._pgresult_ptr) + + def get_value(self, int row_number, int column_number) -> Optional[bytes]: + cdef int crow = row_number + cdef int ccol = column_number + cdef int length = libpq.PQgetlength(self._pgresult_ptr, crow, ccol) + cdef char *v + if length: + v = libpq.PQgetvalue(self._pgresult_ptr, crow, ccol) + # TODO: avoid copy + return v[:length] + else: + if libpq.PQgetisnull(self._pgresult_ptr, crow, ccol): + return None + else: + return b"" + + @property + def nparams(self) -> int: + return libpq.PQnparams(self._pgresult_ptr) + + def param_type(self, int param_number) -> int: + return libpq.PQparamtype(self._pgresult_ptr, param_number) + + @property + def command_status(self) -> Optional[bytes]: + cdef char *rv = libpq.PQcmdStatus(self._pgresult_ptr) + if rv is not NULL: + return rv + else: + return None + + @property + def command_tuples(self) -> Optional[int]: + cdef char *rv = libpq.PQcmdTuples(self._pgresult_ptr) + if rv is NULL: + return None + cdef bytes brv = rv + return int(brv) if brv else None + + @property + def oid_value(self) -> int: + return libpq.PQoidValue(self._pgresult_ptr) + + def set_attributes(self, descriptions: List[PGresAttDesc]): + cdef Py_ssize_t num = len(descriptions) + cdef libpq.PGresAttDesc *attrs = <libpq.PGresAttDesc *>PyMem_Malloc( + num * sizeof(libpq.PGresAttDesc)) + + for i in range(num): + descr = descriptions[i] + attrs[i].name = descr.name + attrs[i].tableid = descr.tableid + attrs[i].columnid = descr.columnid + attrs[i].format = descr.format + attrs[i].typid = descr.typid + attrs[i].typlen = descr.typlen + attrs[i].atttypmod = descr.atttypmod + + cdef int res = libpq.PQsetResultAttrs(self._pgresult_ptr, <int>num, attrs) + PyMem_Free(attrs) + if (res == 0): + raise e.OperationalError("PQsetResultAttrs failed") diff --git a/psycopg_c/psycopg_c/pq/pqbuffer.pyx b/psycopg_c/psycopg_c/pq/pqbuffer.pyx new file mode 100644 index 0000000..eb5d648 --- /dev/null +++ b/psycopg_c/psycopg_c/pq/pqbuffer.pyx @@ -0,0 +1,111 @@ +""" +PQbuffer object implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython +from cpython.bytes cimport PyBytes_AsStringAndSize +from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE +from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release + + +@cython.freelist(32) +cdef class PQBuffer: + """ + Wrap a chunk of memory allocated by the libpq and expose it as memoryview. + """ + @staticmethod + cdef PQBuffer _from_buffer(unsigned char *buf, Py_ssize_t length): + cdef PQBuffer rv = PQBuffer.__new__(PQBuffer) + rv.buf = buf + rv.len = length + return rv + + def __cinit__(self): + self.buf = NULL + self.len = 0 + + def __dealloc__(self): + if self.buf: + libpq.PQfreemem(self.buf) + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__qualname__}" + f"({bytes(self)})" + ) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = self.buf + buffer.obj = self + buffer.len = self.len + buffer.itemsize = sizeof(unsigned char) + buffer.readonly = 1 + buffer.ndim = 1 + buffer.format = NULL # unsigned char + buffer.shape = &self.len + buffer.strides = NULL + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + +@cython.freelist(32) +cdef class ViewBuffer: + """ + Wrap a chunk of memory owned by a different object. + """ + @staticmethod + cdef ViewBuffer _from_buffer( + object obj, unsigned char *buf, Py_ssize_t length + ): + cdef ViewBuffer rv = ViewBuffer.__new__(ViewBuffer) + rv.obj = obj + rv.buf = buf + rv.len = length + return rv + + def __cinit__(self): + self.buf = NULL + self.len = 0 + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__qualname__}" + f"({bytes(self)})" + ) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = self.buf + buffer.obj = self + buffer.len = self.len + buffer.itemsize = sizeof(unsigned char) + buffer.readonly = 1 + buffer.ndim = 1 + buffer.format = NULL # unsigned char + buffer.shape = &self.len + buffer.strides = NULL + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + +cdef int _buffer_as_string_and_size( + data: "Buffer", char **ptr, Py_ssize_t *length +) except -1: + cdef Py_buffer buf + + if isinstance(data, bytes): + PyBytes_AsStringAndSize(data, ptr, length) + elif PyObject_CheckBuffer(data): + PyObject_GetBuffer(data, &buf, PyBUF_SIMPLE) + ptr[0] = <char *>buf.buf + length[0] = buf.len + PyBuffer_Release(&buf) + else: + raise TypeError(f"bytes or buffer expected, got {type(data)}") diff --git a/psycopg_c/psycopg_c/py.typed b/psycopg_c/psycopg_c/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg_c/psycopg_c/py.typed diff --git a/psycopg_c/psycopg_c/types/array.pyx b/psycopg_c/psycopg_c/types/array.pyx new file mode 100644 index 0000000..9abaef9 --- /dev/null +++ b/psycopg_c/psycopg_c/types/array.pyx @@ -0,0 +1,276 @@ +""" +C optimized functions to manipulate arrays +""" + +# Copyright (C) 2022 The Psycopg Team + +import cython + +from libc.stdint cimport int32_t, uint32_t +from libc.string cimport memset, strchr +from cpython.mem cimport PyMem_Realloc, PyMem_Free +from cpython.ref cimport Py_INCREF +from cpython.list cimport PyList_New,PyList_Append, PyList_GetSlice +from cpython.list cimport PyList_GET_ITEM, PyList_SET_ITEM, PyList_GET_SIZE +from cpython.object cimport PyObject + +from psycopg_c.pq cimport _buffer_as_string_and_size +from psycopg_c.pq.libpq cimport Oid +from psycopg_c._psycopg cimport endian + +from psycopg import errors as e + +cdef extern from *: + """ +/* Defined in PostgreSQL in src/include/utils/array.h */ +#define MAXDIM 6 + """ + const int MAXDIM + + +cdef class ArrayLoader(_CRecursiveLoader): + + format = PQ_TEXT + base_oid = 0 + delimiter = b"," + + cdef PyObject *row_loader + cdef char cdelim + + # A memory area which used to unescape elements. + # Keep it here to avoid a malloc per element and to set up exceptions + # to make sure to free it on error. + cdef char *scratch + cdef size_t sclen + + cdef object cload(self, const char *data, size_t length): + if self.cdelim == b"\x00": + self.row_loader = self._tx._c_get_loader( + <PyObject *>self.base_oid, <PyObject *>PQ_TEXT) + self.cdelim = self.delimiter[0] + + return _array_load_text( + data, length, self.row_loader, self.cdelim, + &(self.scratch), &(self.sclen)) + + def __dealloc__(self): + PyMem_Free(self.scratch) + + +@cython.final +cdef class ArrayBinaryLoader(_CRecursiveLoader): + + format = PQ_BINARY + + cdef PyObject *row_loader + + cdef object cload(self, const char *data, size_t length): + rv = _array_load_binary(data, length, self._tx, &(self.row_loader)) + return rv + + +cdef object _array_load_text( + const char *buf, size_t length, PyObject *row_loader, char cdelim, + char **scratch, size_t *sclen +): + if length == 0: + raise e.DataError("malformed array: empty data") + + cdef const char *end = buf + length + + # Remove the dimensions information prefix (``[...]=``) + if buf[0] == b"[": + buf = strchr(buf + 1, b'=') + if buf == NULL: + raise e.DataError("malformed array: no '=' after dimension information") + buf += 1 + + # TODO: further optimization: pre-scan the array to find the array + # dimensions, so that we can preallocate the list sized instead of calling + # append, which is the dominating operation + + cdef list stack = [] + cdef list a = [] + rv = a + cdef PyObject *tmp + + cdef CLoader cloader = None + cdef object pyload = None + if (<RowLoader>row_loader).cloader is not None: + cloader = (<RowLoader>row_loader).cloader + else: + pyload = (<RowLoader>row_loader).loadfunc + + while buf < end: + if buf[0] == b'{': + if stack: + tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1) + PyList_Append(<object>tmp, a) + PyList_Append(stack, a) + a = [] + buf += 1 + + elif buf[0] == b'}': + if not stack: + raise e.DataError("malformed array: unexpected '}'") + rv = stack.pop() + buf += 1 + + elif buf[0] == cdelim: + buf += 1 + + else: + v = _parse_token( + &buf, end, cdelim, scratch, sclen, cloader, pyload) + if not stack: + raise e.DataError("malformed array: missing initial '{'") + tmp = PyList_GET_ITEM(stack, PyList_GET_SIZE(stack) - 1) + PyList_Append(<object>tmp, v) + + return rv + + +cdef object _parse_token( + const char **bufptr, const char *bufend, char cdelim, + char **scratch, size_t *sclen, CLoader cloader, object load +): + cdef const char *start = bufptr[0] + cdef int has_quotes = start[0] == b'"' + cdef int quoted = has_quotes + cdef int num_escapes = 0 + cdef int escaped = 0 + + if has_quotes: + start += 1 + cdef const char *end = start + + while end < bufend: + if (end[0] == cdelim or end[0] == b'}') and not quoted: + break + elif end[0] == b'\\' and not escaped: + num_escapes += 1 + escaped = 1 + end += 1 + continue + elif end[0] == b'"' and not escaped: + quoted = 0 + escaped = 0 + end += 1 + else: + raise e.DataError("malformed array: hit the end of the buffer") + + # Return the new position for the buffer + bufptr[0] = end + if has_quotes: + end -= 1 + + cdef int length = (end - start) + if length == 4 and not has_quotes \ + and start[0] == b'N' and start[1] == b'U' \ + and start[2] == b'L' and start[3] == b'L': + return None + + cdef const char *src + cdef char *tgt + cdef size_t unesclen + + if not num_escapes: + if cloader is not None: + return cloader.cload(start, length) + else: + b = start[:length] + return load(b) + + else: + unesclen = length - num_escapes + 1 + if unesclen > sclen[0]: + scratch[0] = <char *>PyMem_Realloc(scratch[0], unesclen) + sclen[0] = unesclen + + src = start + tgt = scratch[0] + while src < end: + if src[0] == b'\\': + src += 1 + tgt[0] = src[0] + src += 1 + tgt += 1 + + tgt[0] = b'\x00' + + if cloader is not None: + return cloader.cload(scratch[0], length - num_escapes) + else: + b = scratch[0][:length - num_escapes] + return load(b) + + +@cython.cdivision(True) +cdef object _array_load_binary( + const char *buf, size_t length, Transformer tx, PyObject **row_loader_ptr +): + # head is ndims, hasnull, elem oid + cdef uint32_t *buf32 = <uint32_t *>buf + cdef int ndims = endian.be32toh(buf32[0]) + + if ndims <= 0: + return [] + elif ndims > MAXDIM: + raise e.DataError( + r"unexpected number of dimensions %s exceeding the maximum allowed %s" + % (ndims, MAXDIM) + ) + + cdef object oid + if row_loader_ptr[0] == NULL: + oid = <Oid>endian.be32toh(buf32[2]) + row_loader_ptr[0] = tx._c_get_loader(<PyObject *>oid, <PyObject *>PQ_BINARY) + + cdef Py_ssize_t[MAXDIM] dims + cdef int i + for i in range(ndims): + # Every dimension is dim, lower bound + dims[i] = endian.be32toh(buf32[3 + 2 * i]) + + buf += (3 + 2 * ndims) * sizeof(uint32_t) + out = _array_load_binary_rec(ndims, dims, &buf, row_loader_ptr[0]) + return out + + +cdef object _array_load_binary_rec( + Py_ssize_t ndims, Py_ssize_t *dims, const char **bufptr, PyObject *row_loader +): + cdef const char *buf + cdef int i + cdef int32_t size + cdef object val + + cdef Py_ssize_t nelems = dims[0] + cdef list out = PyList_New(nelems) + + if ndims == 1: + buf = bufptr[0] + for i in range(nelems): + size = <int32_t>endian.be32toh((<uint32_t *>buf)[0]) + buf += sizeof(uint32_t) + if size == -1: + val = None + else: + if (<RowLoader>row_loader).cloader is not None: + val = (<RowLoader>row_loader).cloader.cload(buf, size) + else: + val = (<RowLoader>row_loader).loadfunc(buf[:size]) + buf += size + + Py_INCREF(val) + PyList_SET_ITEM(out, i, val) + + bufptr[0] = buf + + else: + for i in range(nelems): + val = _array_load_binary_rec(ndims - 1, dims + 1, bufptr, row_loader) + Py_INCREF(val) + PyList_SET_ITEM(out, i, val) + + return out diff --git a/psycopg_c/psycopg_c/types/bool.pyx b/psycopg_c/psycopg_c/types/bool.pyx new file mode 100644 index 0000000..86cf88e --- /dev/null +++ b/psycopg_c/psycopg_c/types/bool.pyx @@ -0,0 +1,78 @@ +""" +Cython adapters for boolean. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + + +@cython.final +cdef class BoolDumper(CDumper): + + format = PQ_TEXT + oid = oids.BOOL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *buf = CDumper.ensure_size(rv, offset, 1) + + # Fast paths, just a pointer comparison + if obj is True: + buf[0] = b"t" + elif obj is False: + buf[0] = b"f" + elif obj: + buf[0] = b"t" + else: + buf[0] = b"f" + + return 1 + + def quote(self, obj: bool) -> bytes: + if obj is True: + return b"true" + elif obj is False: + return b"false" + else: + return b"true" if obj else b"false" + + +@cython.final +cdef class BoolBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.BOOL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *buf = CDumper.ensure_size(rv, offset, 1) + + # Fast paths, just a pointer comparison + if obj is True: + buf[0] = b"\x01" + elif obj is False: + buf[0] = b"\x00" + elif obj: + buf[0] = b"\x01" + else: + buf[0] = b"\x00" + + return 1 + + +@cython.final +cdef class BoolLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + # this creates better C than `return data[0] == b't'` + return True if data[0] == b't' else False + + +@cython.final +cdef class BoolBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return True if data[0] else False diff --git a/psycopg_c/psycopg_c/types/datetime.pyx b/psycopg_c/psycopg_c/types/datetime.pyx new file mode 100644 index 0000000..51e7dcf --- /dev/null +++ b/psycopg_c/psycopg_c/types/datetime.pyx @@ -0,0 +1,1136 @@ +""" +Cython adapters for date/time types. +""" + +# Copyright (C) 2021 The Psycopg Team + +from libc.string cimport memset, strchr +from cpython cimport datetime as cdt +from cpython.dict cimport PyDict_GetItem +from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs + +cdef extern from "Python.h": + const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL + object PyTimeZone_FromOffset(object offset) + +cdef extern from *: + """ +/* Multipliers from fraction of seconds to microseconds */ +static int _uspad[] = {0, 100000, 10000, 1000, 100, 10, 1}; + """ + cdef int *_uspad + +from datetime import date, time, timedelta, datetime, timezone + +from psycopg_c._psycopg cimport endian + +from psycopg import errors as e +from psycopg._compat import ZoneInfo + + +# Initialise the datetime C API +cdt.import_datetime() + +cdef enum: + ORDER_YMD = 0 + ORDER_DMY = 1 + ORDER_MDY = 2 + ORDER_PGDM = 3 + ORDER_PGMD = 4 + +cdef enum: + INTERVALSTYLE_OTHERS = 0 + INTERVALSTYLE_SQL_STANDARD = 1 + INTERVALSTYLE_POSTGRES = 2 + +cdef enum: + PG_DATE_EPOCH_DAYS = 730120 # date(2000, 1, 1).toordinal() + PY_DATE_MIN_DAYS = 1 # date.min.toordinal() + +cdef object date_toordinal = date.toordinal +cdef object date_fromordinal = date.fromordinal +cdef object datetime_astimezone = datetime.astimezone +cdef object time_utcoffset = time.utcoffset +cdef object timedelta_total_seconds = timedelta.total_seconds +cdef object timezone_utc = timezone.utc +cdef object pg_datetime_epoch = datetime(2000, 1, 1) +cdef object pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc) + +cdef object _month_abbr = { + n: i + for i, n in enumerate( + b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1 + ) +} + + +@cython.final +cdef class DateDumper(CDumper): + + format = PQ_TEXT + oid = oids.DATE_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class DateBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.DATE_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int32_t days = PyObject_CallFunctionObjArgs( + date_toordinal, <PyObject *>obj, NULL) + days -= PG_DATE_EPOCH_DAYS + cdef int32_t *buf = <int32_t *>CDumper.ensure_size( + rv, offset, sizeof(int32_t)) + buf[0] = endian.htobe32(days) + return sizeof(int32_t) + + +cdef class _BaseTimeDumper(CDumper): + + cpdef get_key(self, obj, format): + # Use (cls,) to report the need to upgrade to a dumper for timetz (the + # Frankenstein of the data types). + if not obj.tzinfo: + return self.cls + else: + return (self.cls,) + + cpdef upgrade(self, obj: time, format): + raise NotImplementedError + + +cdef class _BaseTimeTextDumper(_BaseTimeDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class TimeDumper(_BaseTimeTextDumper): + + oid = oids.TIME_OID + + cpdef upgrade(self, obj, format): + if not obj.tzinfo: + return self + else: + return TimeTzDumper(self.cls) + + +@cython.final +cdef class TimeTzDumper(_BaseTimeTextDumper): + + oid = oids.TIMETZ_OID + + +@cython.final +cdef class TimeBinaryDumper(_BaseTimeDumper): + + format = PQ_BINARY + oid = oids.TIME_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = cdt.time_microsecond(obj) + 1000000 * ( + cdt.time_second(obj) + + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj)) + ) + + cdef int64_t *buf = <int64_t *>CDumper.ensure_size( + rv, offset, sizeof(int64_t)) + buf[0] = endian.htobe64(micros) + return sizeof(int64_t) + + cpdef upgrade(self, obj, format): + if not obj.tzinfo: + return self + else: + return TimeTzBinaryDumper(self.cls) + + +@cython.final +cdef class TimeTzBinaryDumper(_BaseTimeDumper): + + format = PQ_BINARY + oid = oids.TIMETZ_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = cdt.time_microsecond(obj) + 1_000_000 * ( + cdt.time_second(obj) + + 60 * (cdt.time_minute(obj) + 60 * <int64_t>cdt.time_hour(obj)) + ) + + off = PyObject_CallFunctionObjArgs(time_utcoffset, <PyObject *>obj, NULL) + cdef int32_t offsec = int(PyObject_CallFunctionObjArgs( + timedelta_total_seconds, <PyObject *>off, NULL)) + + cdef char *buf = CDumper.ensure_size( + rv, offset, sizeof(int64_t) + sizeof(int32_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(-offsec) + + return sizeof(int64_t) + sizeof(int32_t) + + +cdef class _BaseDatetimeDumper(CDumper): + + cpdef get_key(self, obj, format): + # Use (cls,) to report the need to upgrade (downgrade, actually) to a + # dumper for naive timestamp. + if obj.tzinfo: + return self.cls + else: + return (self.cls,) + + cpdef upgrade(self, obj: time, format): + raise NotImplementedError + + +cdef class _BaseDatetimeTextDumper(_BaseDatetimeDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + cdef str s = str(obj) + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class DatetimeDumper(_BaseDatetimeTextDumper): + + oid = oids.TIMESTAMPTZ_OID + + cpdef upgrade(self, obj, format): + if obj.tzinfo: + return self + else: + return DatetimeNoTzDumper(self.cls) + + +@cython.final +cdef class DatetimeNoTzDumper(_BaseDatetimeTextDumper): + + oid = oids.TIMESTAMP_OID + + +@cython.final +cdef class DatetimeBinaryDumper(_BaseDatetimeDumper): + + format = PQ_BINARY + oid = oids.TIMESTAMPTZ_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + delta = obj - pg_datetimetz_epoch + + cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * ( + 86_400 * <int64_t>cdt.timedelta_days(delta) + + <int64_t>cdt.timedelta_seconds(delta)) + + cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + return sizeof(int64_t) + + cpdef upgrade(self, obj, format): + if obj.tzinfo: + return self + else: + return DatetimeNoTzBinaryDumper(self.cls) + + +@cython.final +cdef class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper): + + format = PQ_BINARY + oid = oids.TIMESTAMP_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + delta = obj - pg_datetime_epoch + + cdef int64_t micros = cdt.timedelta_microseconds(delta) + 1_000_000 * ( + 86_400 * <int64_t>cdt.timedelta_days(delta) + + <int64_t>cdt.timedelta_seconds(delta)) + + cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + return sizeof(int64_t) + + +@cython.final +cdef class TimedeltaDumper(CDumper): + + format = PQ_TEXT + oid = oids.INTERVAL_OID + cdef int _style + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_intervalstyle(self._pgconn) + if ds[0] == b's': # sql_standard + self._style = INTERVALSTYLE_SQL_STANDARD + else: # iso_8601, postgres, postgres_verbose + self._style = INTERVALSTYLE_OTHERS + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size; + cdef const char *src + + cdef str s + if self._style == INTERVALSTYLE_OTHERS: + # The comma is parsed ok by PostgreSQL but it's not documented + # and it seems brittle to rely on it. CRDB doesn't consume it well. + s = str(obj).replace(",", "") + else: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + s = "%+d day %+d second %+d microsecond" % ( + obj.days, obj.seconds, obj.microseconds) + + src = PyUnicode_AsUTF8AndSize(s, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class TimedeltaBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INTERVAL_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t micros = ( + 1_000_000 * <int64_t>cdt.timedelta_seconds(obj) + + cdt.timedelta_microseconds(obj)) + cdef int32_t days = cdt.timedelta_days(obj) + + cdef char *buf = CDumper.ensure_size( + rv, offset, sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t)) + (<int64_t *>buf)[0] = endian.htobe64(micros) + (<int32_t *>(buf + sizeof(int64_t)))[0] = endian.htobe32(days) + (<int32_t *>(buf + sizeof(int64_t) + sizeof(int32_t)))[0] = 0 + + return sizeof(int64_t) + sizeof(int32_t) + sizeof(int32_t) + + +@cython.final +cdef class DateLoader(CLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + elif ds[0] == b'G': # German + self._order = ORDER_DMY + elif ds[0] == b'S': # SQL, DMY / MDY + self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY + elif ds[0] == b'P': # Postgres, DMY / MDY + self._order = ORDER_DMY if ds[10] == b'D' else ORDER_MDY + else: + raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + cdef object _error_date(self, const char *data, str msg): + s = bytes(data).decode("utf8", "replace") + if s == "infinity" or len(s.split()[0]) > 10: + raise e.DataError(f"date too large (after year 10K): {s!r}") from None + elif s == "-infinity" or "BC" in s: + raise e.DataError(f"date too small (before year 1): {s!r}") from None + else: + raise e.DataError(f"can't parse date {s!r}: {msg}") from None + + cdef object cload(self, const char *data, size_t length): + if length != 10: + self._error_date(data, "unexpected length") + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + + cdef const char *ptr + cdef const char *end = data + length + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse date {s!r}") + + try: + if self._order == ORDER_YMD: + return cdt.date_new(vals[0], vals[1], vals[2]) + elif self._order == ORDER_DMY: + return cdt.date_new(vals[2], vals[1], vals[0]) + else: + return cdt.date_new(vals[2], vals[0], vals[1]) + except ValueError as ex: + self._error_date(data, str(ex)) + + +@cython.final +cdef class DateBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int days = endian.be32toh((<uint32_t *>data)[0]) + cdef object pydays = days + PG_DATE_EPOCH_DAYS + try: + return PyObject_CallFunctionObjArgs( + date_fromordinal, <PyObject *>pydays, NULL) + except ValueError: + if days < PY_DATE_MIN_DAYS: + raise e.DataError("date too small (before year 1)") from None + else: + raise e.DataError("date too large (after year 10K)") from None + + +@cython.final +cdef class TimeLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + cdef const char *end = data + length + + # Parse the first 3 groups of digits + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse time {s!r}") + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + try: + return cdt.time_new(vals[0], vals[1], vals[2], us, None) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse time {s!r}: {ex}") from None + + +@cython.final +cdef class TimeBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int h, m, s, us + + with cython.cdivision(True): + us = val % 1_000_000 + val //= 1_000_000 + + s = val % 60 + val //= 60 + + m = val % 60 + h = <int>(val // 60) + + try: + return cdt.time_new(h, m, s, us, None) + except ValueError: + raise e.DataError( + f"time not supported by Python: hour={h}" + ) from None + + +@cython.final +cdef class TimetzLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + cdef const char *end = data + length + + # Parse the first 3 groups of digits (time) + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}") + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Parse the timezone + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}") + + tz = _timezone_from_seconds(offsecs) + try: + return cdt.time_new(vals[0], vals[1], vals[2], us, tz) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse timetz {s!r}: {ex}") from None + + +@cython.final +cdef class TimetzBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int32_t off = endian.be32toh((<uint32_t *>(data + sizeof(int64_t)))[0]) + cdef int h, m, s, us + + with cython.cdivision(True): + us = val % 1_000_000 + val //= 1_000_000 + + s = val % 60 + val //= 60 + + m = val % 60 + h = <int>(val // 60) + + tz = _timezone_from_seconds(-off) + try: + return cdt.time_new(h, m, s, us, tz) + except ValueError: + raise e.DataError( + f"time not supported by Python: hour={h}" + ) from None + + +@cython.final +cdef class TimestampLoader(CLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + elif ds[0] == b'G': # German + self._order = ORDER_DMY + elif ds[0] == b'S': # SQL, DMY / MDY + self._order = ORDER_DMY if ds[5] == b'D' else ORDER_MDY + elif ds[0] == b'P': # Postgres, DMY / MDY + self._order = ORDER_PGDM if ds[10] == b'D' else ORDER_PGMD + else: + raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + cdef object cload(self, const char *data, size_t length): + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC + raise _get_timestamp_load_error(self._pgconn, data) from None + + if self._order == ORDER_PGDM or self._order == ORDER_PGMD: + return self._cload_pg(data, end) + + cdef int vals[6] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + + # Parse the first 6 groups of digits (date and time) + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Resolve the YMD order + cdef int y, m, d + if self._order == ORDER_YMD: + y, m, d = vals[0], vals[1], vals[2] + elif self._order == ORDER_DMY: + d, m, y = vals[0], vals[1], vals[2] + else: # self._order == ORDER_MDY + m, d, y = vals[0], vals[1], vals[2] + + try: + return cdt.datetime_new( + y, m, d, vals[3], vals[4], vals[5], us, None) + except ValueError as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + cdef object _cload_pg(self, const char *data, const char *end): + cdef int vals[4] + memset(vals, 0, sizeof(vals)) + cdef const char *ptr + + # Find Wed Jun 02 or Wed 02 Jun + cdef char *seps[3] + seps[0] = strchr(data, b' ') + seps[1] = strchr(seps[0] + 1, b' ') if seps[0] != NULL else NULL + seps[2] = strchr(seps[1] + 1, b' ') if seps[1] != NULL else NULL + if seps[2] == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the following 3 groups of digits (time) + ptr = _parse_date_values(seps[2] + 1, end, vals, 3) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Parse the year + ptr = _parse_date_values(ptr + 1, end, vals + 3, 1) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Resolve the MD order + cdef int m, d + try: + if self._order == ORDER_PGDM: + d = int(seps[0][1 : seps[1] - seps[0]]) + m = _month_abbr[seps[1][1 : seps[2] - seps[1]]] + else: # self._order == ORDER_PGMD + m = _month_abbr[seps[0][1 : seps[1] - seps[0]]] + d = int(seps[1][1 : seps[2] - seps[1]]) + except (KeyError, ValueError) as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + try: + return cdt.datetime_new( + vals[3], m, d, vals[0], vals[1], vals[2], us, None) + except ValueError as ex: + raise _get_timestamp_load_error(self._pgconn, data, ex) from None + + +@cython.final +cdef class TimestampBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int64_t micros, secs, days + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. + cdef int64_t aval = val if val >= 0 else -val + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + secs = aval // 1_000_000 + micros = aval % 1_000_000 + + days = secs // 86_400 + secs %= 86_400 + + try: + delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros) + if val > 0: + return pg_datetime_epoch + delta + else: + return pg_datetime_epoch - delta + + except OverflowError: + if val <= 0: + raise e.DataError("timestamp too small (before year 1)") from None + else: + raise e.DataError("timestamp too large (after year 10K)") from None + + +cdef class _BaseTimestamptzLoader(CLoader): + cdef object _time_zone + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + self._time_zone = _timezone_from_connection(self._pgconn) + + +@cython.final +cdef class TimestamptzLoader(_BaseTimestamptzLoader): + + format = PQ_TEXT + cdef int _order + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_datestyle(self._pgconn) + if ds[0] == b'I': # ISO + self._order = ORDER_YMD + else: # Not true, but any non-YMD will do. + self._order = ORDER_DMY + + cdef object cload(self, const char *data, size_t length): + if self._order != ORDER_YMD: + return self._cload_notimpl(data, length) + + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC + raise _get_timestamp_load_error(self._pgconn, data) from None + + cdef int vals[6] + memset(vals, 0, sizeof(vals)) + + # Parse the first 6 groups of digits (date and time) + cdef const char *ptr + ptr = _parse_date_values(data, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + # Parse the microseconds + cdef int us = 0 + if ptr[0] == b".": + ptr = _parse_micros(ptr + 1, &us) + + # Resolve the YMD order + cdef int y, m, d + if self._order == ORDER_YMD: + y, m, d = vals[0], vals[1], vals[2] + elif self._order == ORDER_DMY: + d, m, y = vals[0], vals[1], vals[2] + else: # self._order == ORDER_MDY + m, d, y = vals[0], vals[1], vals[2] + + # Parse the timezone + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) + if ptr == NULL: + raise _get_timestamp_load_error(self._pgconn, data) from None + + tzoff = cdt.timedelta_new(0, offsecs, 0) + + # The return value is a datetime with the timezone of the connection + # (in order to be consistent with the binary loader, which is the only + # thing it can return). So create a temporary datetime object, in utc, + # shift it by the offset parsed from the timestamp, and then move it to + # the connection timezone. + dt = None + try: + dt = cdt.datetime_new( + y, m, d, vals[3], vals[4], vals[5], us, timezone_utc) + dt -= tzoff + return PyObject_CallFunctionObjArgs(datetime_astimezone, + <PyObject *>dt, <PyObject *>self._time_zone, NULL) + except OverflowError as ex: + # If we have created the temporary 'dt' it means that we have a + # datetime close to max, the shift pushed it past max, overflowing. + # In this case return the datetime in a fixed offset timezone. + if dt is not None: + return dt.replace(tzinfo=timezone(tzoff)) + else: + ex1 = ex + except ValueError as ex: + ex1 = ex + + raise _get_timestamp_load_error(self._pgconn, data, ex1) from None + + cdef object _cload_notimpl(self, const char *data, size_t length): + s = bytes(data)[:length].decode("utf8", "replace") + ds = _get_datestyle(self._pgconn).decode() + raise NotImplementedError( + f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" + ) + + +@cython.final +cdef class TimestamptzBinaryLoader(_BaseTimestamptzLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int64_t micros, secs, days + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. + cdef int64_t aval = val if val >= 0 else -val + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + secs = aval // 1_000_000 + micros = aval % 1_000_000 + + days = secs // 86_400 + secs %= 86_400 + + try: + delta = cdt.timedelta_new(<int>days, <int>secs, <int>micros) + if val > 0: + dt = pg_datetimetz_epoch + delta + else: + dt = pg_datetimetz_epoch - delta + return PyObject_CallFunctionObjArgs(datetime_astimezone, + <PyObject *>dt, <PyObject *>self._time_zone, NULL) + + except OverflowError: + # If we were asked about a timestamp which would overflow in UTC, + # but not in the desired timezone (e.g. datetime.max at Chicago + # timezone) we can still save the day by shifting the value by the + # timezone offset and then replacing the timezone. + if self._time_zone is not None: + utcoff = self._time_zone.utcoffset( + datetime.min if val < 0 else datetime.max + ) + if utcoff: + usoff = 1_000_000 * int(utcoff.total_seconds()) + try: + ts = pg_datetime_epoch + timedelta( + microseconds=val + usoff + ) + except OverflowError: + pass # will raise downstream + else: + return ts.replace(tzinfo=self._time_zone) + + if val <= 0: + raise e.DataError( + "timestamp too small (before year 1)" + ) from None + else: + raise e.DataError( + "timestamp too large (after year 10K)" + ) from None + + +@cython.final +cdef class IntervalLoader(CLoader): + + format = PQ_TEXT + cdef int _style + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + cdef const char *ds = _get_intervalstyle(self._pgconn) + if ds[0] == b'p' and ds[8] == 0: # postgres + self._style = INTERVALSTYLE_POSTGRES + else: # iso_8601, sql_standard, postgres_verbose + self._style = INTERVALSTYLE_OTHERS + + cdef object cload(self, const char *data, size_t length): + if self._style == INTERVALSTYLE_OTHERS: + return self._cload_notimpl(data, length) + + cdef int days = 0, secs = 0, us = 0 + cdef char sign + cdef int val + cdef const char *ptr = data + cdef const char *sep + cdef const char *end = ptr + length + + # If there are spaces, there is a [+|-]n [days|months|years] + while True: + if ptr[0] == b'-' or ptr[0] == b'+': + sign = ptr[0] + ptr += 1 + else: + sign = 0 + + sep = strchr(ptr, b' ') + if sep == NULL or sep > end: + break + + val = 0 + ptr = _parse_date_values(ptr, end, &val, 1) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}") + + if sign == b'-': + val = -val + + if ptr[1] == b'y': + days = 365 * val + elif ptr[1] == b'm': + days = 30 * val + elif ptr[1] == b'd': + days = val + else: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}") + + # Skip the date part word. + ptr = strchr(ptr + 1, b' ') + if ptr != NULL and ptr < end: + ptr += 1 + else: + break + + # Parse the time part. An eventual sign was already consumed in the loop + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + if ptr != NULL: + ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}") + + secs = vals[2] + 60 * (vals[1] + 60 * vals[0]) + + if ptr[0] == b'.': + ptr = _parse_micros(ptr + 1, &us) + + if sign == b'-': + secs = -secs + us = -us + + try: + return cdt.timedelta_new(days, secs, us) + except OverflowError as ex: + s = bytes(data).decode("utf8", "replace") + raise e.DataError(f"can't parse interval {s!r}: {ex}") from None + + cdef object _cload_notimpl(self, const char *data, size_t length): + s = bytes(data).decode("utf8", "replace") + style = _get_intervalstyle(self._pgconn).decode() + raise NotImplementedError( + f"can't parse interval with IntervalStyle {style!r}: {s!r}" + ) + + +@cython.final +cdef class IntervalBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef int64_t val = endian.be64toh((<uint64_t *>data)[0]) + cdef int32_t days = endian.be32toh( + (<uint32_t *>(data + sizeof(int64_t)))[0]) + cdef int32_t months = endian.be32toh( + (<uint32_t *>(data + sizeof(int64_t) + sizeof(int32_t)))[0]) + + cdef int years + with cython.cdivision(True): + if months > 0: + years = months // 12 + months %= 12 + days += 30 * months + 365 * years + elif months < 0: + months = -months + years = months // 12 + months %= 12 + days -= 30 * months + 365 * years + + # Work only with positive values as the cdivision behaves differently + # with negative values, and cdivision=False adds overhead. + cdef int64_t aval = val if val >= 0 else -val + cdef int us, ussecs, usdays + + # Group the micros in biggers stuff or timedelta_new might overflow + with cython.cdivision(True): + ussecs = <int>(aval // 1_000_000) + us = aval % 1_000_000 + + usdays = ussecs // 86_400 + ussecs %= 86_400 + + if val < 0: + ussecs = -ussecs + usdays = -usdays + us = -us + + try: + return cdt.timedelta_new(days + usdays, ussecs, us) + except OverflowError as ex: + raise e.DataError(f"can't parse interval: {ex}") + + +cdef const char *_parse_date_values( + const char *ptr, const char *end, int *vals, int nvals +): + """ + Parse *nvals* numeric values separated by non-numeric chars. + + Write the result in the *vals* array (assumed zeroed) starting from *start*. + + Return the pointer at the separator after the final digit. + """ + cdef int ival = 0 + while ptr < end: + if b'0' <= ptr[0] <= b'9': + vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0') + else: + ival += 1 + if ival >= nvals: + break + + ptr += 1 + + return ptr + + +cdef const char *_parse_micros(const char *start, int *us): + """ + Parse microseconds from a string. + + Micros are assumed up to 6 digit chars separated by a non-digit. + + Return the pointer at the separator after the final digit. + """ + cdef const char *ptr = start + while ptr[0]: + if b'0' <= ptr[0] <= b'9': + us[0] = us[0] * 10 + (ptr[0] - <char>b'0') + else: + break + + ptr += 1 + + # Pad the fraction of second to get millis + if us[0] and ptr - start < 6: + us[0] *= _uspad[ptr - start] + + return ptr + + +cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end): + """ + Parse a timezone from a string, return Python timezone object. + + Modify the buffer pointer to point at the first character after the + timezone parsed. In case of parse error make it NULL. + """ + cdef const char *ptr = bufptr[0] + cdef char sgn = ptr[0] + + # Parse at most three groups of digits + cdef int vals[3] + memset(vals, 0, sizeof(vals)) + + ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals)) + if ptr == NULL: + return 0 + + cdef int off = 60 * (60 * vals[0] + vals[1]) + vals[2] + return -off if sgn == b"-" else off + + +cdef object _timezone_from_seconds(int sec, __cache={}): + cdef object pysec = sec + cdef PyObject *ptr = PyDict_GetItem(__cache, pysec) + if ptr != NULL: + return <object>ptr + + delta = cdt.timedelta_new(0, sec, 0) + tz = timezone(delta) + __cache[pysec] = tz + return tz + + +cdef object _get_timestamp_load_error( + pq.PGconn pgconn, const char *data, ex: Optional[Exception] = None +): + s = bytes(data).decode("utf8", "replace") + + def is_overflow(s): + if not s: + return False + + ds = _get_datestyle(pgconn) + if not ds.startswith(b"P"): # Postgres + return len(s.split()[0]) > 10 # date is first token + else: + return len(s.split()[-1]) > 4 # year is last token + + if s == "-infinity" or s.endswith("BC"): + return e.DataError("timestamp too small (before year 1): {s!r}") + elif s == "infinity" or is_overflow(s): + return e.DataError(f"timestamp too large (after year 10K): {s!r}") + else: + return e.DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}") + + +cdef _timezones = {} +_timezones[None] = timezone_utc +_timezones[b"UTC"] = timezone_utc + + +cdef object _timezone_from_connection(pq.PGconn pgconn): + """Return the Python timezone info of the connection's timezone.""" + if pgconn is None: + return timezone_utc + + cdef bytes tzname = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"TimeZone") + cdef PyObject *ptr = PyDict_GetItem(_timezones, tzname) + if ptr != NULL: + return <object>ptr + + sname = tzname.decode() if tzname else "UTC" + try: + zi = ZoneInfo(sname) + except (KeyError, OSError): + logger.warning( + "unknown PostgreSQL timezone: %r; will use UTC", sname + ) + zi = timezone_utc + except Exception as ex: + logger.warning( + "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)", + sname, + type(ex).__name__, + ex, + ) + zi = timezone.utc + + _timezones[tzname] = zi + return zi + + +cdef const char *_get_datestyle(pq.PGconn pgconn): + cdef const char *ds + if pgconn is not None: + ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"DateStyle") + if ds is not NULL and ds[0]: + return ds + + return b"ISO, DMY" + + +cdef const char *_get_intervalstyle(pq.PGconn pgconn): + cdef const char *ds + if pgconn is not None: + ds = libpq.PQparameterStatus(pgconn._pgconn_ptr, b"IntervalStyle") + if ds is not NULL and ds[0]: + return ds + + return b"postgres" diff --git a/psycopg_c/psycopg_c/types/numeric.pyx b/psycopg_c/psycopg_c/types/numeric.pyx new file mode 100644 index 0000000..893bdc2 --- /dev/null +++ b/psycopg_c/psycopg_c/types/numeric.pyx @@ -0,0 +1,715 @@ +""" +Cython adapters for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + +from libc.stdint cimport * +from libc.string cimport memcpy, strlen +from cpython.mem cimport PyMem_Free +from cpython.dict cimport PyDict_GetItem, PyDict_SetItem +from cpython.long cimport ( + PyLong_FromString, PyLong_FromLong, PyLong_FromLongLong, + PyLong_FromUnsignedLong, PyLong_AsLongLong) +from cpython.bytes cimport PyBytes_AsStringAndSize +from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble +from cpython.unicode cimport PyUnicode_DecodeUTF8 + +from decimal import Decimal, Context, DefaultContext + +from psycopg_c._psycopg cimport endian +from psycopg import errors as e +from psycopg._wrappers import Int2, Int4, Int8, IntNumeric + +cdef extern from "Python.h": + # work around https://github.com/cython/cython/issues/3909 + double PyOS_string_to_double( + const char *s, char **endptr, PyObject *overflow_exception) except? -1.0 + char *PyOS_double_to_string( + double val, char format_code, int precision, int flags, int *ptype + ) except NULL + int Py_DTSF_ADD_DOT_0 + long long PyLong_AsLongLongAndOverflow(object pylong, int *overflow) except? -1 + + # Missing in cpython/unicode.pxd + const char *PyUnicode_AsUTF8(object unicode) except NULL + + +# defined in numutils.c +cdef extern from *: + """ +int pg_lltoa(int64_t value, char *a); +#define MAXINT8LEN 20 + """ + int pg_lltoa(int64_t value, char *a) + const int MAXINT8LEN + + +cdef class _NumberDumper(CDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_int_to_text(obj, rv, offset) + + def quote(self, obj) -> bytearray: + cdef Py_ssize_t length + + rv = PyByteArray_FromStringAndSize("", 0) + if obj >= 0: + length = self.cdump(obj, rv, 0) + else: + PyByteArray_Resize(rv, 23) + rv[0] = b' ' + length = 1 + self.cdump(obj, rv, 1) + + PyByteArray_Resize(rv, length) + return rv + + +@cython.final +cdef class Int2Dumper(_NumberDumper): + + oid = oids.INT2_OID + + +@cython.final +cdef class Int4Dumper(_NumberDumper): + + oid = oids.INT4_OID + + +@cython.final +cdef class Int8Dumper(_NumberDumper): + + oid = oids.INT8_OID + + +@cython.final +cdef class IntNumericDumper(_NumberDumper): + + oid = oids.NUMERIC_OID + + +@cython.final +cdef class Int2BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT2_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int16_t *buf = <int16_t *>CDumper.ensure_size( + rv, offset, sizeof(int16_t)) + cdef int16_t val = <int16_t>PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint16_t *ptvar = <uint16_t *>(&val) + buf[0] = endian.htobe16(ptvar[0]) + return sizeof(int16_t) + + +@cython.final +cdef class Int4BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT4_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int32_t *buf = <int32_t *>CDumper.ensure_size( + rv, offset, sizeof(int32_t)) + cdef int32_t val = <int32_t>PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint32_t *ptvar = <uint32_t *>(&val) + buf[0] = endian.htobe32(ptvar[0]) + return sizeof(int32_t) + + +@cython.final +cdef class Int8BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.INT8_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef int64_t *buf = <int64_t *>CDumper.ensure_size( + rv, offset, sizeof(int64_t)) + cdef int64_t val = PyLong_AsLongLong(obj) + # swap bytes if needed + cdef uint64_t *ptvar = <uint64_t *>(&val) + buf[0] = endian.htobe64(ptvar[0]) + return sizeof(int64_t) + + +cdef extern from *: + """ +/* Ratio between number of bits required to store a number and number of pg + * decimal digits required (log(2) / log(10_000)). + */ +#define BIT_PER_PGDIGIT 0.07525749891599529 + +/* decimal digits per Postgres "digit" */ +#define DEC_DIGITS 4 + +#define NUMERIC_POS 0x0000 +#define NUMERIC_NEG 0x4000 +#define NUMERIC_NAN 0xC000 +#define NUMERIC_PINF 0xD000 +#define NUMERIC_NINF 0xF000 +""" + const double BIT_PER_PGDIGIT + const int DEC_DIGITS + const int NUMERIC_POS + const int NUMERIC_NEG + const int NUMERIC_NAN + const int NUMERIC_PINF + const int NUMERIC_NINF + + +@cython.final +cdef class IntNumericBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_int_to_numeric_binary(obj, rv, offset) + + +cdef class IntDumper(_NumberDumper): + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + raise TypeError( + f"{type(self).__name__} is a dispatcher to other dumpers:" + " dump() is not supposed to be called" + ) + + cpdef get_key(self, obj, format): + cdef long long val + cdef int overflow + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if overflow: + return IntNumeric + + if INT32_MIN <= obj <= INT32_MAX: + if INT16_MIN <= obj <= INT16_MAX: + return Int2 + else: + return Int4 + else: + if INT64_MIN <= obj <= INT64_MAX: + return Int8 + else: + return IntNumeric + + _int2_dumper = Int2Dumper + _int4_dumper = Int4Dumper + _int8_dumper = Int8Dumper + _int_numeric_dumper = IntNumericDumper + + cpdef upgrade(self, obj, format): + cdef long long val + cdef int overflow + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if overflow: + return self._int_numeric_dumper(IntNumeric) + + if INT32_MIN <= obj <= INT32_MAX: + if INT16_MIN <= obj <= INT16_MAX: + return self._int2_dumper(Int2) + else: + return self._int4_dumper(Int4) + else: + if INT64_MIN <= obj <= INT64_MAX: + return self._int8_dumper(Int8) + else: + return self._int_numeric_dumper(IntNumeric) + + +@cython.final +cdef class IntBinaryDumper(IntDumper): + + format = PQ_BINARY + + _int2_dumper = Int2BinaryDumper + _int4_dumper = Int4BinaryDumper + _int8_dumper = Int8BinaryDumper + _int_numeric_dumper = IntNumericBinaryDumper + + +@cython.final +cdef class IntLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + # if the number ends with a 0 we don't need a copy + if data[length] == b'\0': + return PyLong_FromString(data, NULL, 10) + + # Otherwise we have to copy it aside + if length > MAXINT8LEN: + raise ValueError("string too big for an int") + + cdef char[MAXINT8LEN + 1] buf + memcpy(buf, data, length) + buf[length] = 0 + return PyLong_FromString(buf, NULL, 10) + + + +@cython.final +cdef class Int2BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLong(<int16_t>endian.be16toh((<uint16_t *>data)[0])) + + +@cython.final +cdef class Int4BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLong(<int32_t>endian.be32toh((<uint32_t *>data)[0])) + + +@cython.final +cdef class Int8BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromLongLong(<int64_t>endian.be64toh((<uint64_t *>data)[0])) + + +@cython.final +cdef class OidBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return PyLong_FromUnsignedLong(endian.be32toh((<uint32_t *>data)[0])) + + +cdef class _FloatDumper(CDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef double d = PyFloat_AsDouble(obj) + cdef char *out = PyOS_double_to_string( + d, b'r', 0, Py_DTSF_ADD_DOT_0, NULL) + cdef Py_ssize_t length = strlen(out) + cdef char *tgt = CDumper.ensure_size(rv, offset, length) + memcpy(tgt, out, length) + PyMem_Free(out) + return length + + def quote(self, obj) -> bytes: + value = bytes(self.dump(obj)) + cdef PyObject *ptr = PyDict_GetItem(_special_float, value) + if ptr != NULL: + return <object>ptr + + return value if obj >= 0 else b" " + value + +cdef dict _special_float = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", +} + + +@cython.final +cdef class FloatDumper(_FloatDumper): + + oid = oids.FLOAT8_OID + + +@cython.final +cdef class Float4Dumper(_FloatDumper): + + oid = oids.FLOAT4_OID + + +@cython.final +cdef class FloatBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.FLOAT8_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef double d = PyFloat_AsDouble(obj) + cdef uint64_t *intptr = <uint64_t *>&d + cdef uint64_t *buf = <uint64_t *>CDumper.ensure_size( + rv, offset, sizeof(uint64_t)) + buf[0] = endian.htobe64(intptr[0]) + return sizeof(uint64_t) + + +@cython.final +cdef class Float4BinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.FLOAT4_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef float f = <float>PyFloat_AsDouble(obj) + cdef uint32_t *intptr = <uint32_t *>&f + cdef uint32_t *buf = <uint32_t *>CDumper.ensure_size( + rv, offset, sizeof(uint32_t)) + buf[0] = endian.htobe32(intptr[0]) + return sizeof(uint32_t) + + +@cython.final +cdef class FloatLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + cdef char *endptr + cdef double d = PyOS_string_to_double( + data, &endptr, <PyObject *>OverflowError) + return PyFloat_FromDouble(d) + + +@cython.final +cdef class Float4BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef uint32_t asint = endian.be32toh((<uint32_t *>data)[0]) + # avoid warning: + # dereferencing type-punned pointer will break strict-aliasing rules + cdef char *swp = <char *>&asint + return PyFloat_FromDouble((<float *>swp)[0]) + + +@cython.final +cdef class Float8BinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + cdef uint64_t asint = endian.be64toh((<uint64_t *>data)[0]) + cdef char *swp = <char *>&asint + return PyFloat_FromDouble((<double *>swp)[0]) + + +@cython.final +cdef class DecimalDumper(CDumper): + + format = PQ_TEXT + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_decimal_to_text(obj, rv, offset) + + def quote(self, obj) -> bytes: + value = bytes(self.dump(obj)) + cdef PyObject *ptr = PyDict_GetItem(_special_decimal, value) + if ptr != NULL: + return <object>ptr + + return value if obj >= 0 else b" " + value + +cdef dict _special_decimal = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", +} + + +@cython.final +cdef class NumericLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + s = PyUnicode_DecodeUTF8(<char *>data, length, NULL) + return Decimal(s) + + +cdef dict _decimal_special = { + NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + +cdef dict _contexts = {} +for _i in range(DefaultContext.prec): + _contexts[_i] = DefaultContext + + +@cython.final +cdef class NumericBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + + cdef uint16_t *data16 = <uint16_t *>data + cdef uint16_t ndigits = endian.be16toh(data16[0]) + cdef int16_t weight = <int16_t>endian.be16toh(data16[1]) + cdef uint16_t sign = endian.be16toh(data16[2]) + cdef uint16_t dscale = endian.be16toh(data16[3]) + cdef int shift + cdef int i + cdef PyObject *pctx + cdef object key + + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + if length != (4 + ndigits) * sizeof(uint16_t): + raise e.DataError("bad ndigits in numeric binary representation") + + val = 0 + for i in range(ndigits): + val *= 10_000 + val += endian.be16toh(data16[i + 4]) + + shift = dscale - (ndigits - weight - 1) * DEC_DIGITS + + key = (weight + 2) * DEC_DIGITS + dscale + pctx = PyDict_GetItem(_contexts, key) + if pctx == NULL: + ctx = Context(prec=key) + PyDict_SetItem(_contexts, key, ctx) + pctx = <PyObject *>ctx + + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, <object>pctx) + .shift(shift, <object>pctx) + ) + else: + try: + return _decimal_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") + + +@cython.final +cdef class DecimalBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + return dump_decimal_to_numeric_binary(obj, rv, offset) + + +@cython.final +cdef class NumericDumper(CDumper): + + format = PQ_TEXT + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + if isinstance(obj, int): + return dump_int_to_text(obj, rv, offset) + else: + return dump_decimal_to_text(obj, rv, offset) + + +@cython.final +cdef class NumericBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.NUMERIC_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + if isinstance(obj, int): + return dump_int_to_numeric_binary(obj, rv, offset) + else: + return dump_decimal_to_numeric_binary(obj, rv, offset) + + +cdef Py_ssize_t dump_decimal_to_text(obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *src + cdef Py_ssize_t length + cdef char *buf + + b = bytes(str(obj), "utf-8") + PyBytes_AsStringAndSize(b, &src, &length) + + if src[0] != b's': + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, src, length) + + else: # convert sNaN to NaN + length = 3 # NaN + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, b"NaN", length) + + return length + + +cdef extern from *: + """ +/* Weights of py digits into a pg digit according to their positions. */ +static const int pydigit_weights[] = {1000, 100, 10, 1}; +""" + const int[4] pydigit_weights + + +@cython.cdivision(True) +cdef Py_ssize_t dump_decimal_to_numeric_binary( + obj, bytearray rv, Py_ssize_t offset +) except -1: + + # TODO: this implementation is about 30% slower than the text dump. + # This might be probably optimised by accessing the C structure of + # the Decimal object, if available, which would save the creation of + # several intermediate Python objects (the DecimalTuple, the digits + # tuple, and then accessing them). + + cdef object t = obj.as_tuple() + cdef int sign = t[0] + cdef tuple digits = t[1] + cdef uint16_t *buf + cdef Py_ssize_t length + + cdef object pyexp = t[2] + cdef const char *bexp + if not isinstance(pyexp, int): + # Handle inf, nan + length = 4 * sizeof(uint16_t) + buf = <uint16_t *>CDumper.ensure_size(rv, offset, length) + buf[0] = 0 + buf[1] = 0 + buf[3] = 0 + bexp = PyUnicode_AsUTF8(pyexp) + if bexp[0] == b'n' or bexp[0] == b'N': + buf[2] = endian.htobe16(NUMERIC_NAN) + elif bexp[0] == b'F': + if sign: + buf[2] = endian.htobe16(NUMERIC_NINF) + else: + buf[2] = endian.htobe16(NUMERIC_PINF) + else: + raise e.DataError(f"unexpected decimal exponent: {pyexp}") + return length + + cdef int exp = pyexp + cdef uint16_t ndigits = <uint16_t>len(digits) + + # Find the last nonzero digit + cdef int nzdigits = ndigits + while nzdigits > 0 and digits[nzdigits - 1] == 0: + nzdigits -= 1 + + cdef uint16_t dscale + if exp <= 0: + dscale = -exp + else: + dscale = 0 + # align the py digits to the pg digits if there's some py exponent + ndigits += exp % DEC_DIGITS + + if nzdigits == 0: + length = 4 * sizeof(uint16_t) + buf = <uint16_t *>CDumper.ensure_size(rv, offset, length) + buf[0] = 0 # ndigits + buf[1] = 0 # weight + buf[2] = endian.htobe16(NUMERIC_POS) # sign + buf[3] = endian.htobe16(dscale) + return length + + # Equivalent of 0-padding left to align the py digits to the pg digits + # but without changing the digits tuple. + cdef int wi = 0 + cdef int mod = (ndigits - dscale) % DEC_DIGITS + if mod < 0: + # the difference between C and Py % operator + mod += 4 + if mod: + wi = DEC_DIGITS - mod + ndigits += wi + + cdef int tmp = nzdigits + wi + cdef int pgdigits = tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1) + length = (pgdigits + 4) * sizeof(uint16_t) + buf = <uint16_t*>CDumper.ensure_size(rv, offset, length) + buf[0] = endian.htobe16(pgdigits) + buf[1] = endian.htobe16(<int16_t>((ndigits + exp) // DEC_DIGITS - 1)) + buf[2] = endian.htobe16(NUMERIC_NEG) if sign else endian.htobe16(NUMERIC_POS) + buf[3] = endian.htobe16(dscale) + + cdef uint16_t pgdigit = 0 + cdef int bi = 4 + for i in range(nzdigits): + pgdigit += pydigit_weights[wi] * <int>(digits[i]) + wi += 1 + if wi >= DEC_DIGITS: + buf[bi] = endian.htobe16(pgdigit) + pgdigit = wi = 0 + bi += 1 + + if pgdigit: + buf[bi] = endian.htobe16(pgdigit) + + return length + + +cdef Py_ssize_t dump_int_to_text(obj, bytearray rv, Py_ssize_t offset) except -1: + cdef long long val + cdef int overflow + cdef char *buf + cdef char *src + cdef Py_ssize_t length + + # Ensure an int or a subclass. The 'is' type check is fast. + # Passing a float must give an error, but passing an Enum should work. + if type(obj) is not int and not isinstance(obj, int): + raise e.DataError(f"integer expected, got {type(obj).__name__!r}") + + val = PyLong_AsLongLongAndOverflow(obj, &overflow) + if not overflow: + buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1) + length = pg_lltoa(val, buf) + else: + b = bytes(str(obj), "utf-8") + PyBytes_AsStringAndSize(b, &src, &length) + buf = CDumper.ensure_size(rv, offset, length) + memcpy(buf, src, length) + + return length + + +cdef Py_ssize_t dump_int_to_numeric_binary(obj, bytearray rv, Py_ssize_t offset) except -1: + # Calculate the number of PG digits required to store the number + cdef uint16_t ndigits + ndigits = <uint16_t>((<int>obj.bit_length()) * BIT_PER_PGDIGIT) + 1 + + cdef uint16_t sign = NUMERIC_POS + if obj < 0: + sign = NUMERIC_NEG + obj = -obj + + cdef Py_ssize_t length = sizeof(uint16_t) * (ndigits + 4) + cdef uint16_t *buf + buf = <uint16_t *><void *>CDumper.ensure_size(rv, offset, length) + buf[0] = endian.htobe16(ndigits) + buf[1] = endian.htobe16(ndigits - 1) # weight + buf[2] = endian.htobe16(sign) + buf[3] = 0 # dscale + + cdef int i = 4 + ndigits - 1 + cdef uint16_t rem + while obj: + rem = obj % 10000 + obj //= 10000 + buf[i] = endian.htobe16(rem) + i -= 1 + while i > 3: + buf[i] = 0 + i -= 1 + + return length diff --git a/psycopg_c/psycopg_c/types/numutils.c b/psycopg_c/psycopg_c/types/numutils.c new file mode 100644 index 0000000..4be7108 --- /dev/null +++ b/psycopg_c/psycopg_c/types/numutils.c @@ -0,0 +1,243 @@ +/* + * Utilities to deal with numbers. + * + * Copyright (C) 2020 The Psycopg Team + * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + */ + +#include <stdint.h> +#include <string.h> + +#include "pg_config.h" + + +/* + * 64-bit integers + */ +#ifdef HAVE_LONG_INT_64 +/* Plain "long int" fits, use it */ + +# ifndef HAVE_INT64 +typedef long int int64; +# endif +# ifndef HAVE_UINT64 +typedef unsigned long int uint64; +# endif +# define INT64CONST(x) (x##L) +# define UINT64CONST(x) (x##UL) +#elif defined(HAVE_LONG_LONG_INT_64) +/* We have working support for "long long int", use that */ + +# ifndef HAVE_INT64 +typedef long long int int64; +# endif +# ifndef HAVE_UINT64 +typedef unsigned long long int uint64; +# endif +# define INT64CONST(x) (x##LL) +# define UINT64CONST(x) (x##ULL) +#else +/* neither HAVE_LONG_INT_64 nor HAVE_LONG_LONG_INT_64 */ +# error must have a working 64-bit integer datatype +#endif + + +#ifndef HAVE__BUILTIN_CLZ +static const uint8_t pg_leftmost_one_pos[256] = { + 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 +}; +#endif + +static const char DIGIT_TABLE[200] = { + '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', + '7', '0', '8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', + '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2', + '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9', + '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', + '7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', + '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5', + '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9', + '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', + '7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', + '7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8', + '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9', + '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', + '7', '9', '8', '9', '9' +}; + + +/* + * pg_leftmost_one_pos64 + * As above, but for a 64-bit word. + */ +static inline int +pg_leftmost_one_pos64(uint64_t word) +{ +#ifdef HAVE__BUILTIN_CLZ +#if defined(HAVE_LONG_INT_64) + return 63 - __builtin_clzl(word); +#elif defined(HAVE_LONG_LONG_INT_64) + return 63 - __builtin_clzll(word); +#else +#error must have a working 64-bit integer datatype +#endif +#else /* !HAVE__BUILTIN_CLZ */ + int shift = 64 - 8; + + while ((word >> shift) == 0) + shift -= 8; + + return shift + pg_leftmost_one_pos[(word >> shift) & 255]; +#endif /* HAVE__BUILTIN_CLZ */ +} + + +static inline int +decimalLength64(const uint64_t v) +{ + int t; + static const uint64_t PowersOfTen[] = { + UINT64CONST(1), UINT64CONST(10), + UINT64CONST(100), UINT64CONST(1000), + UINT64CONST(10000), UINT64CONST(100000), + UINT64CONST(1000000), UINT64CONST(10000000), + UINT64CONST(100000000), UINT64CONST(1000000000), + UINT64CONST(10000000000), UINT64CONST(100000000000), + UINT64CONST(1000000000000), UINT64CONST(10000000000000), + UINT64CONST(100000000000000), UINT64CONST(1000000000000000), + UINT64CONST(10000000000000000), UINT64CONST(100000000000000000), + UINT64CONST(1000000000000000000), UINT64CONST(10000000000000000000) + }; + + /* + * Compute base-10 logarithm by dividing the base-2 logarithm by a + * good-enough approximation of the base-2 logarithm of 10 + */ + t = (pg_leftmost_one_pos64(v) + 1) * 1233 / 4096; + return t + (v >= PowersOfTen[t]); +} + + +/* + * Get the decimal representation, not NUL-terminated, and return the length of + * same. Caller must ensure that a points to at least MAXINT8LEN bytes. + */ +int +pg_ulltoa_n(uint64_t value, char *a) +{ + int olength, + i = 0; + uint32_t value2; + + /* Degenerate case */ + if (value == 0) + { + *a = '0'; + return 1; + } + + olength = decimalLength64(value); + + /* Compute the result string. */ + while (value >= 100000000) + { + const uint64_t q = value / 100000000; + uint32_t value2 = (uint32_t) (value - 100000000 * q); + + const uint32_t c = value2 % 10000; + const uint32_t d = value2 / 10000; + const uint32_t c0 = (c % 100) << 1; + const uint32_t c1 = (c / 100) << 1; + const uint32_t d0 = (d % 100) << 1; + const uint32_t d1 = (d / 100) << 1; + + char *pos = a + olength - i; + + value = q; + + memcpy(pos - 2, DIGIT_TABLE + c0, 2); + memcpy(pos - 4, DIGIT_TABLE + c1, 2); + memcpy(pos - 6, DIGIT_TABLE + d0, 2); + memcpy(pos - 8, DIGIT_TABLE + d1, 2); + i += 8; + } + + /* Switch to 32-bit for speed */ + value2 = (uint32_t) value; + + if (value2 >= 10000) + { + const uint32_t c = value2 - 10000 * (value2 / 10000); + const uint32_t c0 = (c % 100) << 1; + const uint32_t c1 = (c / 100) << 1; + + char *pos = a + olength - i; + + value2 /= 10000; + + memcpy(pos - 2, DIGIT_TABLE + c0, 2); + memcpy(pos - 4, DIGIT_TABLE + c1, 2); + i += 4; + } + if (value2 >= 100) + { + const uint32_t c = (value2 % 100) << 1; + char *pos = a + olength - i; + + value2 /= 100; + + memcpy(pos - 2, DIGIT_TABLE + c, 2); + i += 2; + } + if (value2 >= 10) + { + const uint32_t c = value2 << 1; + char *pos = a + olength - i; + + memcpy(pos - 2, DIGIT_TABLE + c, 2); + } + else + *a = (char) ('0' + value2); + + return olength; +} + +/* + * pg_lltoa: converts a signed 64-bit integer to its string representation and + * returns strlen(a). + * + * Caller must ensure that 'a' points to enough memory to hold the result + * (at least MAXINT8LEN + 1 bytes, counting a leading sign and trailing NUL). + */ +int +pg_lltoa(int64_t value, char *a) +{ + uint64_t uvalue = value; + int len = 0; + + if (value < 0) + { + uvalue = (uint64_t) 0 - uvalue; + a[len++] = '-'; + } + + len += pg_ulltoa_n(uvalue, a + len); + a[len] = '\0'; + return len; +} diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx new file mode 100644 index 0000000..da18b01 --- /dev/null +++ b/psycopg_c/psycopg_c/types/string.pyx @@ -0,0 +1,315 @@ +""" +Cython adapters for textual types. +""" + +# Copyright (C) 2020 The Psycopg Team + +cimport cython + +from libc.string cimport memcpy, memchr +from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize +from cpython.unicode cimport ( + PyUnicode_AsEncodedString, + PyUnicode_AsUTF8String, + PyUnicode_CheckExact, + PyUnicode_Decode, + PyUnicode_DecodeUTF8, +) + +from psycopg_c.pq cimport libpq, Escaping, _buffer_as_string_and_size + +from psycopg import errors as e +from psycopg._encodings import pg2pyenc + +cdef extern from "Python.h": + const char *PyUnicode_AsUTF8AndSize(unicode obj, Py_ssize_t *size) except NULL + + +cdef class _BaseStrDumper(CDumper): + cdef int is_utf8 + cdef char *encoding + cdef bytes _bytes_encoding # needed to keep `encoding` alive + + def __cinit__(self, cls, context: Optional[AdaptContext] = None): + + self.is_utf8 = 0 + self.encoding = "utf-8" + cdef const char *pgenc + + if self._pgconn is not None: + pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding") + if pgenc == NULL or pgenc == b"UTF8": + self._bytes_encoding = b"utf-8" + self.is_utf8 = 1 + else: + self._bytes_encoding = pg2pyenc(pgenc).encode() + if self._bytes_encoding == b"ascii": + self.is_utf8 = 1 + self.encoding = PyBytes_AsString(self._bytes_encoding) + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + # the server will raise DataError subclass if the string contains 0x00 + cdef Py_ssize_t size; + cdef const char *src + + if self.is_utf8: + # Probably the fastest path, but doesn't work with subclasses + if PyUnicode_CheckExact(obj): + src = PyUnicode_AsUTF8AndSize(obj, &size) + else: + b = PyUnicode_AsUTF8String(obj) + PyBytes_AsStringAndSize(b, <char **>&src, &size) + else: + b = PyUnicode_AsEncodedString(obj, self.encoding, NULL) + PyBytes_AsStringAndSize(b, <char **>&src, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +cdef class _StrBinaryDumper(_BaseStrDumper): + + format = PQ_BINARY + + +@cython.final +cdef class StrBinaryDumper(_StrBinaryDumper): + + oid = oids.TEXT_OID + + +@cython.final +cdef class StrBinaryDumperVarchar(_StrBinaryDumper): + + oid = oids.VARCHAR_OID + + +@cython.final +cdef class StrBinaryDumperName(_StrBinaryDumper): + + oid = oids.NAME_OID + + +cdef class _StrDumper(_BaseStrDumper): + + format = PQ_TEXT + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef Py_ssize_t size = StrBinaryDumper.cdump(self, obj, rv, offset) + + # Like the binary dump, but check for 0, or the string will be truncated + cdef const char *buf = PyByteArray_AS_STRING(rv) + if NULL != memchr(buf + offset, 0x00, size): + raise e.DataError( + "PostgreSQL text fields cannot contain NUL (0x00) bytes" + ) + return size + + +@cython.final +cdef class StrDumper(_StrDumper): + + oid = oids.TEXT_OID + + +@cython.final +cdef class StrDumperVarchar(_StrDumper): + + oid = oids.VARCHAR_OID + + +@cython.final +cdef class StrDumperName(_StrDumper): + + oid = oids.NAME_OID + + +@cython.final +cdef class StrDumperUnknown(_StrDumper): + pass + + +cdef class _TextLoader(CLoader): + + format = PQ_TEXT + + cdef int is_utf8 + cdef char *encoding + cdef bytes _bytes_encoding # needed to keep `encoding` alive + + def __cinit__(self, oid: int, context: Optional[AdaptContext] = None): + + self.is_utf8 = 0 + self.encoding = "utf-8" + cdef const char *pgenc + + if self._pgconn is not None: + pgenc = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"client_encoding") + if pgenc == NULL or pgenc == b"UTF8": + self._bytes_encoding = b"utf-8" + self.is_utf8 = 1 + else: + self._bytes_encoding = pg2pyenc(pgenc).encode() + + if pgenc == b"SQL_ASCII": + self.encoding = NULL + else: + self.encoding = PyBytes_AsString(self._bytes_encoding) + + cdef object cload(self, const char *data, size_t length): + if self.is_utf8: + return PyUnicode_DecodeUTF8(<char *>data, length, NULL) + elif self.encoding: + return PyUnicode_Decode(<char *>data, length, self.encoding, NULL) + else: + return data[:length] + +@cython.final +cdef class TextLoader(_TextLoader): + + format = PQ_TEXT + + +@cython.final +cdef class TextBinaryLoader(_TextLoader): + + format = PQ_BINARY + + +@cython.final +cdef class BytesDumper(CDumper): + + format = PQ_TEXT + oid = oids.BYTEA_OID + + # 0: not set, 1: just single "'" quote, 3: " E'" quote + cdef int _qplen + + def __cinit__(self): + self._qplen = 0 + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + + _buffer_as_string_and_size(obj, &ptr, &length) + + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + out = libpq.PQescapeByteaConn( + self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_out) + else: + out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_out) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {length} bytes" + ) + + len_out -= 1 # out includes final 0 + cdef char *buf = CDumper.ensure_size(rv, offset, len_out) + memcpy(buf, out, len_out) + libpq.PQfreemem(out) + return len_out + + def quote(self, obj): + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + cdef const char *scs + + escaped = self.dump(obj) + _buffer_as_string_and_size(escaped, &ptr, &length) + + rv = PyByteArray_FromStringAndSize("", 0) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self._pgconn is not None: + if not self._qplen: + scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, + b"standard_conforming_strings") + if scs and scs[0] == b'o' and scs[1] == b"n": # == "on" + self._qplen = 1 + else: + self._qplen = 3 + + PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + if self._qplen == 1: + ptr_out[0] = b"'" + else: + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + self._qplen, ptr, length) + ptr_out[length + self._qplen] = b"'" + return rv + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + PyByteArray_Resize(rv, length + 4) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + esc = Escaping() + if esc.escape_bytea(b"\x00") == b"\\000": + rv = bytes(rv).replace(b"\\", b"\\\\") + + return rv + + +@cython.final +cdef class BytesBinaryDumper(CDumper): + + format = PQ_BINARY + oid = oids.BYTEA_OID + + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: + cdef char *src + cdef Py_ssize_t size; + _buffer_as_string_and_size(obj, &src, &size) + + cdef char *buf = CDumper.ensure_size(rv, offset, size) + memcpy(buf, src, size) + return size + + +@cython.final +cdef class ByteaLoader(CLoader): + + format = PQ_TEXT + + cdef object cload(self, const char *data, size_t length): + cdef size_t len_out + cdef unsigned char *out = libpq.PQunescapeBytea( + <const unsigned char *>data, &len_out) + if out is NULL: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = out[:len_out] + libpq.PQfreemem(out) + return rv + + +@cython.final +cdef class ByteaBinaryLoader(CLoader): + + format = PQ_BINARY + + cdef object cload(self, const char *data, size_t length): + return data[:length] diff --git a/psycopg_c/psycopg_c/version.py b/psycopg_c/psycopg_c/version.py new file mode 100644 index 0000000..5c989c2 --- /dev/null +++ b/psycopg_c/psycopg_c/version.py @@ -0,0 +1,11 @@ +""" +psycopg-c distribution version file. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ +__version__ = "3.1.7" + +# also change psycopg/psycopg/version.py accordingly. diff --git a/psycopg_c/pyproject.toml b/psycopg_c/pyproject.toml new file mode 100644 index 0000000..f0d7a3f --- /dev/null +++ b/psycopg_c/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0.0a11"] +build-backend = "setuptools.build_meta" diff --git a/psycopg_c/setup.cfg b/psycopg_c/setup.cfg new file mode 100644 index 0000000..6c5c93c --- /dev/null +++ b/psycopg_c/setup.cfg @@ -0,0 +1,57 @@ +[metadata] +name = psycopg-c +description = PostgreSQL database adapter for Python -- C optimisation distribution +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg-c/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +setup_requires = Cython >= 3.0.0a11 +packages = find: +zip_safe = False + +[options.package_data] +# NOTE: do not include .pyx files: they shouldn't be in the sdist +# package, so that build is only performed from the .c files (which are +# distributed instead). +psycopg_c = + py.typed + *.pyi + *.pxd + _psycopg/*.pxd + pq/*.pxd + +# In the psycopg-binary distribution don't include cython-related files. +psycopg_binary = + py.typed + *.pyi diff --git a/psycopg_c/setup.py b/psycopg_c/setup.py new file mode 100644 index 0000000..c6da3a1 --- /dev/null +++ b/psycopg_c/setup.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - optimisation package +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import sys +import subprocess as sp + +from setuptools import setup, Extension +from distutils.command.build_ext import build_ext +from distutils import log + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +with open("psycopg_c/version.py") as f: + data = f.read() + m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data) + if m is None: + raise Exception(f"cannot find version in {f.name}") + version = m.group(1) + + +def get_config(what: str) -> str: + pg_config = "pg_config" + try: + out = sp.run([pg_config, f"--{what}"], stdout=sp.PIPE, check=True) + except Exception as e: + log.error(f"couldn't run {pg_config!r} --{what}: %s", e) + raise + else: + return out.stdout.strip().decode() + + +class psycopg_build_ext(build_ext): + def finalize_options(self) -> None: + self._setup_ext_build() + super().finalize_options() + + def _setup_ext_build(self) -> None: + cythonize = None + + # In the sdist there are not .pyx, only c, so we don't need Cython. + # Otherwise Cython is a requirement and it is used to compile pyx to c. + if os.path.exists("psycopg_c/_psycopg.pyx"): + from Cython.Build import cythonize + + # Add include and lib dir for the libpq. + includedir = get_config("includedir") + libdir = get_config("libdir") + for ext in self.distribution.ext_modules: + ext.include_dirs.append(includedir) + ext.library_dirs.append(libdir) + + if sys.platform == "win32": + # For __imp_htons and others + ext.libraries.append("ws2_32") + + if cythonize is not None: + for ext in self.distribution.ext_modules: + for i in range(len(ext.sources)): + base, fext = os.path.splitext(ext.sources[i]) + if fext == ".c" and os.path.exists(base + ".pyx"): + ext.sources[i] = base + ".pyx" + + self.distribution.ext_modules = cythonize( + self.distribution.ext_modules, + language_level=3, + compiler_directives={ + "always_allow_keywords": False, + }, + annotate=False, # enable to get an html view of the C module + ) + else: + self.distribution.ext_modules = [pgext, pqext] + + +# MSVC requires an explicit "libpq" +libpq = "pq" if sys.platform != "win32" else "libpq" + +# Some details missing, to be finished by psycopg_build_ext.finalize_options +pgext = Extension( + "psycopg_c._psycopg", + [ + "psycopg_c/_psycopg.c", + "psycopg_c/types/numutils.c", + ], + libraries=[libpq], + include_dirs=[], +) + +pqext = Extension( + "psycopg_c.pq", + ["psycopg_c/pq.c"], + libraries=[libpq], + include_dirs=[], +) + +setup( + version=version, + ext_modules=[pgext, pqext], + cmdclass={"build_ext": psycopg_build_ext}, +) diff --git a/psycopg_pool/.flake8 b/psycopg_pool/.flake8 new file mode 100644 index 0000000..2ae629c --- /dev/null +++ b/psycopg_pool/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 diff --git a/psycopg_pool/LICENSE.txt b/psycopg_pool/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg_pool/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg_pool/README.rst b/psycopg_pool/README.rst new file mode 100644 index 0000000..6e6b32c --- /dev/null +++ b/psycopg_pool/README.rst @@ -0,0 +1,24 @@ +Psycopg 3: PostgreSQL database adapter for Python - Connection Pool +=================================================================== + +This distribution contains the optional connection pool package +`psycopg_pool`__. + +.. __: https://www.psycopg.org/psycopg3/docs/advanced/pool.html + +This package is kept separate from the main ``psycopg`` package because it is +likely that it will follow a different release cycle. + +You can also install this package using:: + + pip install "psycopg[pool]" + +Please read `the project readme`__ and `the installation documentation`__ for +more details. + +.. __: https://github.com/psycopg/psycopg#readme +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html + #installing-the-connection-pool + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg_pool/psycopg_pool/__init__.py b/psycopg_pool/psycopg_pool/__init__.py new file mode 100644 index 0000000..e4d975f --- /dev/null +++ b/psycopg_pool/psycopg_pool/__init__.py @@ -0,0 +1,22 @@ +""" +psycopg connection pool package +""" + +# Copyright (C) 2021 The Psycopg Team + +from .pool import ConnectionPool +from .pool_async import AsyncConnectionPool +from .null_pool import NullConnectionPool +from .null_pool_async import AsyncNullConnectionPool +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from .version import __version__ as __version__ # noqa: F401 + +__all__ = [ + "AsyncConnectionPool", + "AsyncNullConnectionPool", + "ConnectionPool", + "NullConnectionPool", + "PoolClosed", + "PoolTimeout", + "TooManyRequests", +] diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py new file mode 100644 index 0000000..9fb2b9b --- /dev/null +++ b/psycopg_pool/psycopg_pool/_compat.py @@ -0,0 +1,51 @@ +""" +compatibility functions for different Python versions +""" + +# Copyright (C) 2021 The Psycopg Team + +import sys +import asyncio +from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar +from typing_extensions import TypeAlias + +import psycopg.errors as e + +T = TypeVar("T") +FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] + +if sys.version_info >= (3, 8): + create_task = asyncio.create_task + Task = asyncio.Task + +else: + + def create_task( + coro: FutureT[T], name: Optional[str] = None + ) -> "asyncio.Future[T]": + return asyncio.create_task(coro) + + Task = asyncio.Future + +if sys.version_info >= (3, 9): + from collections import Counter, deque as Deque +else: + from typing import Counter, Deque + +__all__ = [ + "Counter", + "Deque", + "Task", + "create_task", +] + +# Workaround for psycopg < 3.0.8. +# Timeout on NullPool connection mignt not work correctly. +try: + ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout +except AttributeError: + + class DummyConnectionTimeout(e.OperationalError): + pass + + ConnectionTimeout = DummyConnectionTimeout diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py new file mode 100644 index 0000000..298ea68 --- /dev/null +++ b/psycopg_pool/psycopg_pool/base.py @@ -0,0 +1,230 @@ +""" +psycopg connection pool base class and functionalities. +""" + +# Copyright (C) 2021 The Psycopg Team + +from time import monotonic +from random import random +from typing import Any, Callable, Dict, Generic, Optional, Tuple + +from psycopg import errors as e +from psycopg.abc import ConnectionType + +from .errors import PoolClosed +from ._compat import Counter, Deque + + +class BasePool(Generic[ConnectionType]): + + # Used to generate pool names + _num_pool = 0 + + # Stats keys + _POOL_MIN = "pool_min" + _POOL_MAX = "pool_max" + _POOL_SIZE = "pool_size" + _POOL_AVAILABLE = "pool_available" + _REQUESTS_WAITING = "requests_waiting" + _REQUESTS_NUM = "requests_num" + _REQUESTS_QUEUED = "requests_queued" + _REQUESTS_WAIT_MS = "requests_wait_ms" + _REQUESTS_ERRORS = "requests_errors" + _USAGE_MS = "usage_ms" + _RETURNS_BAD = "returns_bad" + _CONNECTIONS_NUM = "connections_num" + _CONNECTIONS_MS = "connections_ms" + _CONNECTIONS_ERRORS = "connections_errors" + _CONNECTIONS_LOST = "connections_lost" + + def __init__( + self, + conninfo: str = "", + *, + kwargs: Optional[Dict[str, Any]] = None, + min_size: int = 4, + max_size: Optional[int] = None, + open: bool = True, + name: Optional[str] = None, + timeout: float = 30.0, + max_waiting: int = 0, + max_lifetime: float = 60 * 60.0, + max_idle: float = 10 * 60.0, + reconnect_timeout: float = 5 * 60.0, + reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None, + num_workers: int = 3, + ): + min_size, max_size = self._check_size(min_size, max_size) + + if not name: + num = BasePool._num_pool = BasePool._num_pool + 1 + name = f"pool-{num}" + + if num_workers < 1: + raise ValueError("num_workers must be at least 1") + + self.conninfo = conninfo + self.kwargs: Dict[str, Any] = kwargs or {} + self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None] + self._reconnect_failed = reconnect_failed or (lambda pool: None) + self.name = name + self._min_size = min_size + self._max_size = max_size + self.timeout = timeout + self.max_waiting = max_waiting + self.reconnect_timeout = reconnect_timeout + self.max_lifetime = max_lifetime + self.max_idle = max_idle + self.num_workers = num_workers + + self._nconns = min_size # currently in the pool, out, being prepared + self._pool = Deque[ConnectionType]() + self._stats = Counter[str]() + + # Min number of connections in the pool in a max_idle unit of time. + # It is reset periodically by the ShrinkPool scheduled task. + # It is used to shrink back the pool if maxcon > min_size and extra + # connections have been acquired, if we notice that in the last + # max_idle interval they weren't all used. + self._nconns_min = min_size + + # Flag to allow the pool to grow only one connection at time. In case + # of spike, if threads are allowed to grow in parallel and connection + # time is slow, there won't be any thread available to return the + # connections to the pool. + self._growing = False + + self._opened = False + self._closed = True + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__module__}.{self.__class__.__name__}" + f" {self.name!r} at 0x{id(self):x}>" + ) + + @property + def min_size(self) -> int: + return self._min_size + + @property + def max_size(self) -> int: + return self._max_size + + @property + def closed(self) -> bool: + """`!True` if the pool is closed.""" + return self._closed + + def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]: + if max_size is None: + max_size = min_size + + if min_size < 0: + raise ValueError("min_size cannot be negative") + if max_size < min_size: + raise ValueError("max_size must be greater or equal than min_size") + if min_size == max_size == 0: + raise ValueError("if min_size is 0 max_size must be greater or than 0") + + return min_size, max_size + + def _check_open(self) -> None: + if self._closed and self._opened: + raise e.OperationalError( + "pool has already been opened/closed and cannot be reused" + ) + + def _check_open_getconn(self) -> None: + if self._closed: + if self._opened: + raise PoolClosed(f"the pool {self.name!r} is already closed") + else: + raise PoolClosed(f"the pool {self.name!r} is not open yet") + + def _check_pool_putconn(self, conn: ConnectionType) -> None: + pool = getattr(conn, "_pool", None) + if pool is self: + return + + if pool: + msg = f"it comes from pool {pool.name!r}" + else: + msg = "it doesn't come from any pool" + raise ValueError( + f"can't return connection to pool {self.name!r}, {msg}: {conn}" + ) + + def get_stats(self) -> Dict[str, int]: + """ + Return current stats about the pool usage. + """ + rv = dict(self._stats) + rv.update(self._get_measures()) + return rv + + def pop_stats(self) -> Dict[str, int]: + """ + Return current stats about the pool usage. + + After the call, all the counters are reset to zero. + """ + stats, self._stats = self._stats, Counter() + rv = dict(stats) + rv.update(self._get_measures()) + return rv + + def _get_measures(self) -> Dict[str, int]: + """ + Return immediate measures of the pool (not counters). + """ + return { + self._POOL_MIN: self._min_size, + self._POOL_MAX: self._max_size, + self._POOL_SIZE: self._nconns, + self._POOL_AVAILABLE: len(self._pool), + } + + @classmethod + def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float: + """ + Add a random value to *value* between *min_pc* and *max_pc* percent. + """ + return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc) + + def _set_connection_expiry_date(self, conn: ConnectionType) -> None: + """Set an expiry date on a connection. + + Add some randomness to avoid mass reconnection. + """ + conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0) + + +class ConnectionAttempt: + """Keep the state of a connection attempt.""" + + INITIAL_DELAY = 1.0 + DELAY_JITTER = 0.1 + DELAY_BACKOFF = 2.0 + + def __init__(self, *, reconnect_timeout: float): + self.reconnect_timeout = reconnect_timeout + self.delay = 0.0 + self.give_up_at = 0.0 + + def update_delay(self, now: float) -> None: + """Calculate how long to wait for a new connection attempt""" + if self.delay == 0.0: + self.give_up_at = now + self.reconnect_timeout + self.delay = BasePool._jitter( + self.INITIAL_DELAY, -self.DELAY_JITTER, self.DELAY_JITTER + ) + else: + self.delay *= self.DELAY_BACKOFF + + if self.delay + now > self.give_up_at: + self.delay = max(0.0, self.give_up_at - now) + + def time_to_give_up(self, now: float) -> bool: + """Return True if we are tired of trying to connect. Meh.""" + return self.give_up_at > 0.0 and now >= self.give_up_at diff --git a/psycopg_pool/psycopg_pool/errors.py b/psycopg_pool/psycopg_pool/errors.py new file mode 100644 index 0000000..9e672ad --- /dev/null +++ b/psycopg_pool/psycopg_pool/errors.py @@ -0,0 +1,25 @@ +""" +Connection pool errors. +""" + +# Copyright (C) 2021 The Psycopg Team + +from psycopg import errors as e + + +class PoolClosed(e.OperationalError): + """Attempt to get a connection from a closed pool.""" + + __module__ = "psycopg_pool" + + +class PoolTimeout(e.OperationalError): + """The pool couldn't provide a connection in acceptable time.""" + + __module__ = "psycopg_pool" + + +class TooManyRequests(e.OperationalError): + """Too many requests in the queue waiting for a connection from the pool.""" + + __module__ = "psycopg_pool" diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py new file mode 100644 index 0000000..c0a77c2 --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -0,0 +1,159 @@ +""" +Psycopg null connection pools +""" + +# Copyright (C) 2022 The Psycopg Team + +import logging +import threading +from typing import Any, Optional, Tuple + +from psycopg import Connection +from psycopg.pq import TransactionStatus + +from .pool import ConnectionPool, AddConnection +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout + +logger = logging.getLogger("psycopg.pool") + + +class _BaseNullConnectionPool: + def __init__( + self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any + ): + super().__init__( # type: ignore[call-arg] + conninfo, *args, min_size=min_size, **kwargs + ) + + def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]: + if max_size is None: + max_size = min_size + + if min_size != 0: + raise ValueError("null pools must have min_size = 0") + if max_size < min_size: + raise ValueError("max_size must be greater or equal than min_size") + + return min_size, max_size + + def _start_initial_tasks(self) -> None: + # Null pools don't have background tasks to fill connections + # or to grow/shrink. + return + + def _maybe_grow_pool(self) -> None: + # null pools don't grow + pass + + +class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): + def wait(self, timeout: float = 30.0) -> None: + """ + Create a connection for test. + + Calling this function will verify that the connectivity with the + database works as expected. However the connection will not be stored + in the pool. + + Close the pool, and raise `PoolTimeout`, if not ready within *timeout* + sec. + """ + self._check_open_getconn() + + with self._lock: + assert not self._pool_full_event + self._pool_full_event = threading.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + if not self._pool_full_event.wait(timeout): + self.close() # stop all the threads + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") + + with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[Connection[Any]]: + conn: Optional[Connection[Any]] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + def _maybe_close_connection(self, conn: Connection[Any]) -> bool: + with self._lock: + if not self._closed and self._waiting: + return False + + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + conn.close() + self._nconns -= 1 + return True + + def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + """Change the size of the pool during runtime. + + Only *max_size* can be changed; *min_size* must remain 0. + """ + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + with self._lock: + self._min_size = min_size + self._max_size = max_size + + def check(self) -> None: + """No-op, as the pool doesn't have connections in its state.""" + pass + + def _add_to_pool(self, conn: Connection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shouldn't decrease the + # count of the number of connection used. + self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py new file mode 100644 index 0000000..ae9d207 --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -0,0 +1,122 @@ +""" +psycopg asynchronous null connection pool +""" + +# Copyright (C) 2022 The Psycopg Team + +import asyncio +import logging +from typing import Any, Optional + +from psycopg import AsyncConnection +from psycopg.pq import TransactionStatus + +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout +from .null_pool import _BaseNullConnectionPool +from .pool_async import AsyncConnectionPool, AddConnection + +logger = logging.getLogger("psycopg.pool") + + +class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): + async def wait(self, timeout: float = 30.0) -> None: + self._check_open_getconn() + + async with self._lock: + assert not self._pool_full_event + self._pool_full_event = asyncio.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the tasks + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) from None + + async with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + async def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[AsyncConnection[Any]]: + conn: Optional[AsyncConnection[Any]] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = await self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + # Close the connection if no client is waiting for it, or if the pool + # is closed. For extra refcare remove the pool reference from it. + # Maintain the stats. + async with self._lock: + if not self._closed and self._waiting: + return False + + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + await conn.close() + self._nconns -= 1 + return True + + async def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + async with self._lock: + self._min_size = min_size + self._max_size = max_size + + async def check(self) -> None: + pass + + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + await conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shouldn't decrease the + # count of the number of connection used. + self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py new file mode 100644 index 0000000..609d95d --- /dev/null +++ b/psycopg_pool/psycopg_pool/pool.py @@ -0,0 +1,839 @@ +""" +psycopg synchronous connection pool +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +import threading +from abc import ABC, abstractmethod +from time import monotonic +from queue import Queue, Empty +from types import TracebackType +from typing import Any, Callable, Dict, Iterator, List +from typing import Optional, Sequence, Type +from weakref import ref +from contextlib import contextmanager + +from psycopg import errors as e +from psycopg import Connection +from psycopg.pq import TransactionStatus + +from .base import ConnectionAttempt, BasePool +from .sched import Scheduler +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from ._compat import Deque + +logger = logging.getLogger("psycopg.pool") + + +class ConnectionPool(BasePool[Connection[Any]]): + def __init__( + self, + conninfo: str = "", + *, + open: bool = True, + connection_class: Type[Connection[Any]] = Connection, + configure: Optional[Callable[[Connection[Any]], None]] = None, + reset: Optional[Callable[[Connection[Any]], None]] = None, + **kwargs: Any, + ): + self.connection_class = connection_class + self._configure = configure + self._reset = reset + + self._lock = threading.RLock() + self._waiting = Deque["WaitingClient"]() + + # to notify that the pool is full + self._pool_full_event: Optional[threading.Event] = None + + self._sched = Scheduler() + self._sched_runner: Optional[threading.Thread] = None + self._tasks: "Queue[MaintenanceTask]" = Queue() + self._workers: List[threading.Thread] = [] + + super().__init__(conninfo, **kwargs) + + if open: + self.open() + + def __del__(self) -> None: + # If the '_closed' property is not set we probably failed in __init__. + # Don't try anything complicated as probably it won't work. + if getattr(self, "_closed", True): + return + + self._stop_workers() + + def wait(self, timeout: float = 30.0) -> None: + """ + Wait for the pool to be full (with `min_size` connections) after creation. + + Close the pool, and raise `PoolTimeout`, if not ready within *timeout* + sec. + + Calling this method is not mandatory: you can try and use the pool + immediately after its creation. The first client will be served as soon + as a connection is ready. You can use this method if you prefer your + program to terminate in case the environment is not configured + properly, rather than trying to stay up the hardest it can. + """ + self._check_open_getconn() + + with self._lock: + assert not self._pool_full_event + if len(self._pool) >= self._min_size: + return + self._pool_full_event = threading.Event() + + logger.info("waiting for pool %r initialization", self.name) + if not self._pool_full_event.wait(timeout): + self.close() # stop all the threads + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") + + with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + @contextmanager + def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]: + """Context manager to obtain a connection from the pool. + + Return the connection immediately if available, otherwise wait up to + *timeout* or `self.timeout` seconds and throw `PoolTimeout` if a + connection is not available in time. + + Upon context exit, return the connection to the pool. Apply the normal + :ref:`connection context behaviour <with-connection>` (commit/rollback + the transaction in case of success/error). If the connection is no more + in working state, replace it with a new one. + """ + conn = self.getconn(timeout=timeout) + t0 = monotonic() + try: + with conn: + yield conn + finally: + t1 = monotonic() + self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) + self.putconn(conn) + + def getconn(self, timeout: Optional[float] = None) -> Connection[Any]: + """Obtain a connection from the pool. + + You should preferably use `connection()`. Use this function only if + it is not possible to use the connection as context manager. + + After using this function you *must* call a corresponding `putconn()`: + failing to do so will deplete the pool. A depleted pool is a sad pool: + you don't want a depleted pool. + """ + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + with self._lock: + self._check_open_getconn() + conn = self._get_ready_connection(timeout) + if not conn: + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = WaitingClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If there is space for the pool to grow, let's do it + self._maybe_grow_pool() + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if not conn: + if timeout is None: + timeout = self.timeout + try: + conn = pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + # Note that this property shouldn't be set while the connection is in + # the pool, to avoid to create a reference loop. + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[Connection[Any]]: + """Return a connection, if the client deserves one.""" + conn: Optional[Connection[Any]] = None + if self._pool: + # Take a connection ready out of the pool + conn = self._pool.popleft() + if len(self._pool) < self._nconns_min: + self._nconns_min = len(self._pool) + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + def _maybe_grow_pool(self) -> None: + # Allow only one thread at time to grow the pool (or returning + # connections might be starved). + if self._nconns >= self._max_size or self._growing: + return + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self._growing = True + self.run_task(AddConnection(self, growing=True)) + + def putconn(self, conn: Connection[Any]) -> None: + """Return a connection to the loving hands of its pool. + + Use this function only paired with a `getconn()`. You don't need to use + it if you use the much more comfortable `connection()` context manager. + """ + # Quick check to discard the wrong connection + self._check_pool_putconn(conn) + + logger.info("returning connection to %r", self.name) + + if self._maybe_close_connection(conn): + return + + # Use a worker to perform eventual maintenance work in a separate thread + if self._reset: + self.run_task(ReturnConnection(self, conn)) + else: + self._return_connection(conn) + + def _maybe_close_connection(self, conn: Connection[Any]) -> bool: + """Close a returned connection if necessary. + + Return `!True if the connection was closed. + """ + # If the pool is closed just close the connection instead of returning + # it to the pool. For extra refcare remove the pool reference from it. + if not self._closed: + return False + + conn._pool = None + conn.close() + return True + + def open(self, wait: bool = False, timeout: float = 30.0) -> None: + """Open the pool by starting connecting and and accepting clients. + + If *wait* is `!False`, return immediately and let the background worker + fill the pool if `min_size` > 0. Otherwise wait up to *timeout* seconds + for the requested number of connections to be ready (see `wait()` for + details). + + It is safe to call `!open()` again on a pool already open (because the + method was already called, or because the pool context was entered, or + because the pool was initialized with *open* = `!True`) but you cannot + currently re-open a closed pool. + """ + with self._lock: + self._open() + + if wait: + self.wait(timeout=timeout) + + def _open(self) -> None: + if not self._closed: + return + + self._check_open() + + self._closed = False + self._opened = True + + self._start_workers() + self._start_initial_tasks() + + def _start_workers(self) -> None: + self._sched_runner = threading.Thread( + target=self._sched.run, + name=f"{self.name}-scheduler", + daemon=True, + ) + assert not self._workers + for i in range(self.num_workers): + t = threading.Thread( + target=self.worker, + args=(self._tasks,), + name=f"{self.name}-worker-{i}", + daemon=True, + ) + self._workers.append(t) + + # The object state is complete. Start the worker threads + self._sched_runner.start() + for t in self._workers: + t.start() + + def _start_initial_tasks(self) -> None: + # populate the pool with initial min_size connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over min_size have + # remained unused. + self.schedule_task(ShrinkPool(self), self.max_idle) + + def close(self, timeout: float = 5.0) -> None: + """Close the pool and make it unavailable to new clients. + + All the waiting and future clients will fail to acquire a connection + with a `PoolClosed` exception. Currently used connections will not be + closed until returned to the pool. + + Wait *timeout* seconds for threads to terminate their job, if positive. + If the timeout expires the pool is closed anyway, although it may raise + some warnings on exit. + """ + if self._closed: + return + + with self._lock: + self._closed = True + logger.debug("pool %r closed", self.name) + + # Take waiting client and pool connections out of the state + waiting = list(self._waiting) + self._waiting.clear() + connections = list(self._pool) + self._pool.clear() + + # Now that the flag _closed is set, getconn will fail immediately, + # putconn will just close the returned connection. + self._stop_workers(waiting, connections, timeout) + + def _stop_workers( + self, + waiting_clients: Sequence["WaitingClient"] = (), + connections: Sequence[Connection[Any]] = (), + timeout: float = 0.0, + ) -> None: + + # Stop the scheduler + self._sched.enter(0, None) + + # Stop the worker threads + workers, self._workers = self._workers[:], [] + for i in range(len(workers)): + self.run_task(StopWorker(self)) + + # Signal to eventual clients in the queue that business is closed. + for pos in waiting_clients: + pos.fail(PoolClosed(f"the pool {self.name!r} is closed")) + + # Close the connections still in the pool + for conn in connections: + conn.close() + + # Wait for the worker threads to terminate + assert self._sched_runner is not None + sched_runner, self._sched_runner = self._sched_runner, None + if timeout > 0: + for t in [sched_runner] + workers: + if not t.is_alive(): + continue + t.join(timeout) + if t.is_alive(): + logger.warning( + "couldn't stop thread %s in pool %r within %s seconds", + t, + self.name, + timeout, + ) + + def __enter__(self) -> "ConnectionPool": + self.open() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + """Change the size of the pool during runtime.""" + min_size, max_size = self._check_size(min_size, max_size) + + ngrow = max(0, min_size - self._min_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + with self._lock: + self._min_size = min_size + self._max_size = max_size + self._nconns += ngrow + + for i in range(ngrow): + self.run_task(AddConnection(self)) + + def check(self) -> None: + """Verify the state of the connections currently in the pool. + + Test each connection: if it works return it to the pool, otherwise + dispose of it and create a new one. + """ + with self._lock: + conns = list(self._pool) + self._pool.clear() + + # Give a chance to the pool to grow if it has no connection. + # In case there are enough connection, or the pool is already + # growing, this is a no-op. + self._maybe_grow_pool() + + while conns: + conn = conns.pop() + try: + conn.execute("SELECT 1") + if conn.pgconn.transaction_status == TransactionStatus.INTRANS: + conn.rollback() + except Exception: + self._stats[self._CONNECTIONS_LOST] += 1 + logger.warning("discarding broken connection: %s", conn) + self.run_task(AddConnection(self)) + else: + self._add_to_pool(conn) + + def reconnect_failed(self) -> None: + """ + Called when reconnection failed for longer than `reconnect_timeout`. + """ + self._reconnect_failed(self) + + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker thread.""" + self._tasks.put_nowait(task) + + def schedule_task(self, task: "MaintenanceTask", delay: float) -> None: + """Run a maintenance task in a worker thread in the future.""" + self._sched.enter(delay, task.tick) + + _WORKER_TIMEOUT = 60.0 + + @classmethod + def worker(cls, q: "Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a separate thread. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + # Don't make all the workers time out at the same moment + timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1) + while True: + # Use a timeout to make the wait interruptible + try: + task = q.get(timeout=timeout) + except Empty: + continue + + if isinstance(task, StopWorker): + logger.debug( + "terminating working thread %s", + threading.current_thread().name, + ) + return + + # Run the task. Make sure don't die in the attempt. + try: + task.run() + except Exception as ex: + logger.warning( + "task run %s failed: %s: %s", + task, + ex.__class__.__name__, + ex, + ) + + def _connect(self, timeout: Optional[float] = None) -> Connection[Any]: + """Return a new connection configured for the pool.""" + self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) + t0 = monotonic() + try: + conn: Connection[Any] + conn = self.connection_class.connect(self.conninfo, **kwargs) + except Exception: + self._stats[self._CONNECTIONS_ERRORS] += 1 + raise + else: + t1 = monotonic() + self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0)) + + conn._pool = self + + if self._configure: + self._configure(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by configure function" + f" {self._configure}: discarded" + ) + + # Set an expiry date, with some randomness to avoid mass reconnection + self._set_connection_expiry_date(conn) + return conn + + def _add_connection( + self, attempt: Optional[ConnectionAttempt], growing: bool = False + ) -> None: + """Try to connect and add the connection to the pool. + + If failed, reschedule a new attempt in the future for a few times, then + give up, decrease the pool connections number and call + `self.reconnect_failed()`. + + """ + now = monotonic() + if not attempt: + attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout) + + try: + conn = self._connect() + except Exception as ex: + logger.warning(f"error connecting in {self.name!r}: {ex}") + if attempt.time_to_give_up(now): + logger.warning( + "reconnection attempt in pool %r failed after %s sec", + self.name, + self.reconnect_timeout, + ) + with self._lock: + self._nconns -= 1 + # If we have given up with a growing attempt, allow a new one. + if growing and self._growing: + self._growing = False + self.reconnect_failed() + else: + attempt.update_delay(now) + self.schedule_task( + AddConnection(self, attempt, growing=growing), + attempt.delay, + ) + return + + logger.info("adding new connection to the pool") + self._add_to_pool(conn) + if growing: + with self._lock: + # Keep on growing if the pool is not full yet, or if there are + # clients waiting and the pool can extend. + if self._nconns < self._min_size or ( + self._nconns < self._max_size and self._waiting + ): + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self.run_task(AddConnection(self, growing=True)) + else: + self._growing = False + + def _return_connection(self, conn: Connection[Any]) -> None: + """ + Return a connection to the pool after usage. + """ + self._reset_connection(conn) + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + # Connection no more in working state: create a new one. + self.run_task(AddConnection(self)) + logger.warning("discarding closed connection: %s", conn) + return + + # Check if the connection is past its best before date + if conn._expire_at <= monotonic(): + self.run_task(AddConnection(self)) + logger.info("discarding expired connection") + conn.close() + return + + self._add_to_pool(conn) + + def _add_to_pool(self, conn: Connection[Any]) -> None: + """ + Add a connection to the pool. + + The connection can be a fresh one or one already used in the pool. + + If a client is already waiting for a connection pass it on, otherwise + put it back into the pool + """ + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if pos.set(conn): + break + else: + # No client waiting for a connection: put it back into the pool + self._pool.append(conn) + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is full. + if self._pool_full_event and len(self._pool) >= self._min_size: + self._pool_full_event.set() + + def _reset_connection(self, conn: Connection[Any]) -> None: + """ + Bring a connection to IDLE state or close it. + """ + status = conn.pgconn.transaction_status + if status == TransactionStatus.IDLE: + pass + + elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + # Connection returned with an active transaction + logger.warning("rolling back returned connection: %s", conn) + try: + conn.rollback() + except Exception as ex: + logger.warning( + "rollback failed: %s: %s. Discarding connection %s", + ex.__class__.__name__, + ex, + conn, + ) + conn.close() + + elif status == TransactionStatus.ACTIVE: + # Connection returned during an operation. Bad... just close it. + logger.warning("closing returned connection: %s", conn) + conn.close() + + if not conn.closed and self._reset: + try: + self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + conn.close() + + def _shrink_pool(self) -> None: + to_close: Optional[Connection[Any]] = None + + with self._lock: + # Reset the min number of connections used + nconns_min = self._nconns_min + self._nconns_min = len(self._pool) + + # If the pool can shrink and connections were unused, drop one + if self._nconns > self._min_size and nconns_min > 0: + to_close = self._pool.popleft() + self._nconns -= 1 + self._nconns_min -= 1 + + if to_close: + logger.info( + "shrinking pool %r to %s because %s unused connections" + " in the last %s sec", + self.name, + self._nconns, + nconns_min, + self.max_idle, + ) + to_close.close() + + def _get_measures(self) -> Dict[str, int]: + rv = super()._get_measures() + rv[self._REQUESTS_WAITING] = len(self._waiting) + return rv + + +class WaitingClient: + """A position in a queue for a client waiting for a connection.""" + + __slots__ = ("conn", "error", "_cond") + + def __init__(self) -> None: + self.conn: Optional[Connection[Any]] = None + self.error: Optional[Exception] = None + + # The WaitingClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = threading.Condition() + + def wait(self, timeout: float) -> Connection[Any]: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + with self._cond: + if not (self.conn or self.error): + if not self._cond.wait(timeout): + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error + raise self.error + + def set(self, conn: Connection[Any]) -> bool: + """Signal the client waiting that a connection is ready. + + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). + """ + with self._cond: + if self.conn or self.error: + return False + + self.error = error + self._cond.notify_all() + return True + + +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "ConnectionPool"): + self.pool = ref(pool) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "<pool is gone>" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + def run(self) -> None: + """Run the task. + + This usually happens in a worker thread. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task run discarded: %s", self) + return + + logger.debug("task running in %s: %s", threading.current_thread().name, self) + self._run(pool) + + def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler thread. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task tick discarded: %s", self) + return + + pool.run_task(self) + + @abstractmethod + def _run(self, pool: "ConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance thread to terminate.""" + + def _run(self, pool: "ConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "ConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + growing: bool = False, + ): + super().__init__(pool) + self.attempt = attempt + self.growing = growing + + def _run(self, pool: "ConnectionPool") -> None: + pool._add_connection(self.attempt, growing=self.growing) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"): + super().__init__(pool) + self.conn = conn + + def _run(self, pool: "ConnectionPool") -> None: + pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + def _run(self, pool: "ConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. + pool.schedule_task(self, pool.max_idle) + pool._shrink_pool() diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py new file mode 100644 index 0000000..0ea6e9a --- /dev/null +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -0,0 +1,784 @@ +""" +psycopg asynchronous connection pool +""" + +# Copyright (C) 2021 The Psycopg Team + +import asyncio +import logging +from abc import ABC, abstractmethod +from time import monotonic +from types import TracebackType +from typing import Any, AsyncIterator, Awaitable, Callable +from typing import Dict, List, Optional, Sequence, Type +from weakref import ref +from contextlib import asynccontextmanager + +from psycopg import errors as e +from psycopg import AsyncConnection +from psycopg.pq import TransactionStatus + +from .base import ConnectionAttempt, BasePool +from .sched import AsyncScheduler +from .errors import PoolClosed, PoolTimeout, TooManyRequests +from ._compat import Task, create_task, Deque + +logger = logging.getLogger("psycopg.pool") + + +class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): + def __init__( + self, + conninfo: str = "", + *, + open: bool = True, + connection_class: Type[AsyncConnection[Any]] = AsyncConnection, + configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + **kwargs: Any, + ): + self.connection_class = connection_class + self._configure = configure + self._reset = reset + + # asyncio objects, created on open to attach them to the right loop. + self._lock: asyncio.Lock + self._sched: AsyncScheduler + self._tasks: "asyncio.Queue[MaintenanceTask]" + + self._waiting = Deque["AsyncClient"]() + + # to notify that the pool is full + self._pool_full_event: Optional[asyncio.Event] = None + + self._sched_runner: Optional[Task[None]] = None + self._workers: List[Task[None]] = [] + + super().__init__(conninfo, **kwargs) + + if open: + self._open() + + async def wait(self, timeout: float = 30.0) -> None: + self._check_open_getconn() + + async with self._lock: + assert not self._pool_full_event + if len(self._pool) >= self._min_size: + return + self._pool_full_event = asyncio.Event() + + logger.info("waiting for pool %r initialization", self.name) + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the tasks + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) from None + + async with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + @asynccontextmanager + async def connection( + self, timeout: Optional[float] = None + ) -> AsyncIterator[AsyncConnection[Any]]: + conn = await self.getconn(timeout=timeout) + t0 = monotonic() + try: + async with conn: + yield conn + finally: + t1 = monotonic() + self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) + await self.putconn(conn) + + async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + self._check_open_getconn() + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + async with self._lock: + conn = await self._get_ready_connection(timeout) + if not conn: + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = AsyncClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If there is space for the pool to grow, let's do it + self._maybe_grow_pool() + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if not conn: + if timeout is None: + timeout = self.timeout + try: + conn = await pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + # Note that this property shouldn't be set while the connection is in + # the pool, to avoid to create a reference loop. + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + async def _get_ready_connection( + self, timeout: Optional[float] + ) -> Optional[AsyncConnection[Any]]: + conn: Optional[AsyncConnection[Any]] = None + if self._pool: + # Take a connection ready out of the pool + conn = self._pool.popleft() + if len(self._pool) < self._nconns_min: + self._nconns_min = len(self._pool) + elif self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has already" + f" {len(self._waiting)} requests waiting" + ) + return conn + + def _maybe_grow_pool(self) -> None: + # Allow only one task at time to grow the pool (or returning + # connections might be starved). + if self._nconns < self._max_size and not self._growing: + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self._growing = True + self.run_task(AddConnection(self, growing=True)) + + async def putconn(self, conn: AsyncConnection[Any]) -> None: + self._check_pool_putconn(conn) + + logger.info("returning connection to %r", self.name) + if await self._maybe_close_connection(conn): + return + + # Use a worker to perform eventual maintenance work in a separate task + if self._reset: + self.run_task(ReturnConnection(self, conn)) + else: + await self._return_connection(conn) + + async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + # If the pool is closed just close the connection instead of returning + # it to the pool. For extra refcare remove the pool reference from it. + if not self._closed: + return False + + conn._pool = None + await conn.close() + return True + + async def open(self, wait: bool = False, timeout: float = 30.0) -> None: + # Make sure the lock is created after there is an event loop + try: + self._lock + except AttributeError: + self._lock = asyncio.Lock() + + async with self._lock: + self._open() + + if wait: + await self.wait(timeout=timeout) + + def _open(self) -> None: + if not self._closed: + return + + # Throw a RuntimeError if the pool is open outside a running loop. + asyncio.get_running_loop() + + self._check_open() + + # Create these objects now to attach them to the right loop. + # See #219 + self._tasks = asyncio.Queue() + self._sched = AsyncScheduler() + # This has been most likely, but not necessarily, created in `open()`. + try: + self._lock + except AttributeError: + self._lock = asyncio.Lock() + + self._closed = False + self._opened = True + + self._start_workers() + self._start_initial_tasks() + + def _start_workers(self) -> None: + self._sched_runner = create_task( + self._sched.run(), name=f"{self.name}-scheduler" + ) + for i in range(self.num_workers): + t = create_task( + self.worker(self._tasks), + name=f"{self.name}-worker-{i}", + ) + self._workers.append(t) + + def _start_initial_tasks(self) -> None: + # populate the pool with initial min_size connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over min_size have + # remained unused. + self.run_task(Schedule(self, ShrinkPool(self), self.max_idle)) + + async def close(self, timeout: float = 5.0) -> None: + if self._closed: + return + + async with self._lock: + self._closed = True + logger.debug("pool %r closed", self.name) + + # Take waiting client and pool connections out of the state + waiting = list(self._waiting) + self._waiting.clear() + connections = list(self._pool) + self._pool.clear() + + # Now that the flag _closed is set, getconn will fail immediately, + # putconn will just close the returned connection. + await self._stop_workers(waiting, connections, timeout) + + async def _stop_workers( + self, + waiting_clients: Sequence["AsyncClient"] = (), + connections: Sequence[AsyncConnection[Any]] = (), + timeout: float = 0.0, + ) -> None: + # Stop the scheduler + await self._sched.enter(0, None) + + # Stop the worker tasks + workers, self._workers = self._workers[:], [] + for w in workers: + self.run_task(StopWorker(self)) + + # Signal to eventual clients in the queue that business is closed. + for pos in waiting_clients: + await pos.fail(PoolClosed(f"the pool {self.name!r} is closed")) + + # Close the connections still in the pool + for conn in connections: + await conn.close() + + # Wait for the worker tasks to terminate + assert self._sched_runner is not None + sched_runner, self._sched_runner = self._sched_runner, None + wait = asyncio.gather(sched_runner, *workers) + try: + if timeout > 0: + await asyncio.wait_for(asyncio.shield(wait), timeout=timeout) + else: + await wait + except asyncio.TimeoutError: + logger.warning( + "couldn't stop pool %r tasks within %s seconds", + self.name, + timeout, + ) + + async def __aenter__(self) -> "AsyncConnectionPool": + await self.open() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + ngrow = max(0, min_size - self._min_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + async with self._lock: + self._min_size = min_size + self._max_size = max_size + self._nconns += ngrow + + for i in range(ngrow): + self.run_task(AddConnection(self)) + + async def check(self) -> None: + async with self._lock: + conns = list(self._pool) + self._pool.clear() + + # Give a chance to the pool to grow if it has no connection. + # In case there are enough connection, or the pool is already + # growing, this is a no-op. + self._maybe_grow_pool() + + while conns: + conn = conns.pop() + try: + await conn.execute("SELECT 1") + if conn.pgconn.transaction_status == TransactionStatus.INTRANS: + await conn.rollback() + except Exception: + self._stats[self._CONNECTIONS_LOST] += 1 + logger.warning("discarding broken connection: %s", conn) + self.run_task(AddConnection(self)) + else: + await self._add_to_pool(conn) + + def reconnect_failed(self) -> None: + """ + Called when reconnection failed for longer than `reconnect_timeout`. + """ + self._reconnect_failed(self) + + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker.""" + self._tasks.put_nowait(task) + + async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None: + """Run a maintenance task in a worker in the future.""" + await self._sched.enter(delay, task.tick) + + @classmethod + async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a task. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + while True: + task = await q.get() + + if isinstance(task, StopWorker): + logger.debug("terminating working task") + return + + # Run the task. Make sure don't die in the attempt. + try: + await task.run() + except Exception as ex: + logger.warning( + "task run %s failed: %s: %s", + task, + ex.__class__.__name__, + ex, + ) + + async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) + t0 = monotonic() + try: + conn: AsyncConnection[Any] + conn = await self.connection_class.connect(self.conninfo, **kwargs) + except Exception: + self._stats[self._CONNECTIONS_ERRORS] += 1 + raise + else: + t1 = monotonic() + self._stats[self._CONNECTIONS_MS] += int(1000.0 * (t1 - t0)) + + conn._pool = self + + if self._configure: + await self._configure(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by configure function" + f" {self._configure}: discarded" + ) + + # Set an expiry date, with some randomness to avoid mass reconnection + self._set_connection_expiry_date(conn) + return conn + + async def _add_connection( + self, attempt: Optional[ConnectionAttempt], growing: bool = False + ) -> None: + """Try to connect and add the connection to the pool. + + If failed, reschedule a new attempt in the future for a few times, then + give up, decrease the pool connections number and call + `self.reconnect_failed()`. + + """ + now = monotonic() + if not attempt: + attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout) + + try: + conn = await self._connect() + except Exception as ex: + logger.warning(f"error connecting in {self.name!r}: {ex}") + if attempt.time_to_give_up(now): + logger.warning( + "reconnection attempt in pool %r failed after %s sec", + self.name, + self.reconnect_timeout, + ) + async with self._lock: + self._nconns -= 1 + # If we have given up with a growing attempt, allow a new one. + if growing and self._growing: + self._growing = False + self.reconnect_failed() + else: + attempt.update_delay(now) + await self.schedule_task( + AddConnection(self, attempt, growing=growing), + attempt.delay, + ) + return + + logger.info("adding new connection to the pool") + await self._add_to_pool(conn) + if growing: + async with self._lock: + # Keep on growing if the pool is not full yet, or if there are + # clients waiting and the pool can extend. + if self._nconns < self._min_size or ( + self._nconns < self._max_size and self._waiting + ): + self._nconns += 1 + logger.info("growing pool %r to %s", self.name, self._nconns) + self.run_task(AddConnection(self, growing=True)) + else: + self._growing = False + + async def _return_connection(self, conn: AsyncConnection[Any]) -> None: + """ + Return a connection to the pool after usage. + """ + await self._reset_connection(conn) + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + # Connection no more in working state: create a new one. + self.run_task(AddConnection(self)) + logger.warning("discarding closed connection: %s", conn) + return + + # Check if the connection is past its best before date + if conn._expire_at <= monotonic(): + self.run_task(AddConnection(self)) + logger.info("discarding expired connection") + await conn.close() + return + + await self._add_to_pool(conn) + + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + """ + Add a connection to the pool. + + The connection can be a fresh one or one already used in the pool. + + If a client is already waiting for a connection pass it on, otherwise + put it back into the pool + """ + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: put it back into the pool + self._pool.append(conn) + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is full. + if self._pool_full_event and len(self._pool) >= self._min_size: + self._pool_full_event.set() + + async def _reset_connection(self, conn: AsyncConnection[Any]) -> None: + """ + Bring a connection to IDLE state or close it. + """ + status = conn.pgconn.transaction_status + if status == TransactionStatus.IDLE: + pass + + elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + # Connection returned with an active transaction + logger.warning("rolling back returned connection: %s", conn) + try: + await conn.rollback() + except Exception as ex: + logger.warning( + "rollback failed: %s: %s. Discarding connection %s", + ex.__class__.__name__, + ex, + conn, + ) + await conn.close() + + elif status == TransactionStatus.ACTIVE: + # Connection returned during an operation. Bad... just close it. + logger.warning("closing returned connection: %s", conn) + await conn.close() + + if not conn.closed and self._reset: + try: + await self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + sname = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {sname} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + await conn.close() + + async def _shrink_pool(self) -> None: + to_close: Optional[AsyncConnection[Any]] = None + + async with self._lock: + # Reset the min number of connections used + nconns_min = self._nconns_min + self._nconns_min = len(self._pool) + + # If the pool can shrink and connections were unused, drop one + if self._nconns > self._min_size and nconns_min > 0: + to_close = self._pool.popleft() + self._nconns -= 1 + self._nconns_min -= 1 + + if to_close: + logger.info( + "shrinking pool %r to %s because %s unused connections" + " in the last %s sec", + self.name, + self._nconns, + nconns_min, + self.max_idle, + ) + await to_close.close() + + def _get_measures(self) -> Dict[str, int]: + rv = super()._get_measures() + rv[self._REQUESTS_WAITING] = len(self._waiting) + return rv + + +class AsyncClient: + """A position in a queue for a client waiting for a connection.""" + + __slots__ = ("conn", "error", "_cond") + + def __init__(self) -> None: + self.conn: Optional[AsyncConnection[Any]] = None + self.error: Optional[Exception] = None + + # The AsyncClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = asyncio.Condition() + + async def wait(self, timeout: float) -> AsyncConnection[Any]: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + async with self._cond: + if not (self.conn or self.error): + try: + await asyncio.wait_for(self._cond.wait(), timeout) + except asyncio.TimeoutError: + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error + raise self.error + + async def set(self, conn: AsyncConnection[Any]) -> bool: + """Signal the client waiting that a connection is ready. + + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + async def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.error = error + self._cond.notify_all() + return True + + +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "AsyncConnectionPool"): + self.pool = ref(pool) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "<pool is gone>" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + async def run(self) -> None: + """Run the task. + + This usually happens in a worker. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task run discarded: %s", self) + return + + await self._run(pool) + + async def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler task. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + logger.debug("task tick discarded: %s", self) + return + + pool.run_task(self) + + @abstractmethod + async def _run(self, pool: "AsyncConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance worker to terminate.""" + + async def _run(self, pool: "AsyncConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "AsyncConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + growing: bool = False, + ): + super().__init__(pool) + self.attempt = attempt + self.growing = growing + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._add_connection(self.attempt, growing=self.growing) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"): + super().__init__(pool) + self.conn = conn + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + async def _run(self, pool: "AsyncConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. + await pool.schedule_task(self, pool.max_idle) + await pool._shrink_pool() + + +class Schedule(MaintenanceTask): + """Schedule a task in the pool scheduler. + + This task is a trampoline to allow to use a sync call (pool.run_task) + to execute an async one (pool.schedule_task). + """ + + def __init__( + self, + pool: "AsyncConnectionPool", + task: MaintenanceTask, + delay: float, + ): + super().__init__(pool) + self.task = task + self.delay = delay + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool.schedule_task(self.task, self.delay) diff --git a/psycopg_pool/psycopg_pool/py.typed b/psycopg_pool/psycopg_pool/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg_pool/psycopg_pool/py.typed diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py new file mode 100644 index 0000000..ca26007 --- /dev/null +++ b/psycopg_pool/psycopg_pool/sched.py @@ -0,0 +1,177 @@ +""" +A minimal scheduler to schedule tasks run in the future. + +Inspired to the standard library `sched.scheduler`, but designed for +multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in +front of the one currently running and `Scheduler.run()` can be left running +without any task scheduled. + +Tasks are called "Task", not "Event", here, because we actually make use of +`threading.Event` and the two would be confusing. +""" + +# Copyright (C) 2021 The Psycopg Team + +import asyncio +import logging +import threading +from time import monotonic +from heapq import heappush, heappop +from typing import Any, Callable, List, Optional, NamedTuple + +logger = logging.getLogger(__name__) + + +class Task(NamedTuple): + time: float + action: Optional[Callable[[], Any]] + + def __eq__(self, other: "Task") -> Any: # type: ignore[override] + return self.time == other.time + + def __lt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time < other.time + + def __le__(self, other: "Task") -> Any: # type: ignore[override] + return self.time <= other.time + + def __gt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time > other.time + + def __ge__(self, other: "Task") -> Any: # type: ignore[override] + return self.time >= other.time + + +class Scheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = threading.RLock() + self._event = threading.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return self.enterabs(time, action) + + def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. + """ + task = Task(time, action) + with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + if not task.action: + break + try: + task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + self._event.wait(timeout=delay) + + +class AsyncScheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = asyncio.Lock() + self._event = asyncio.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return await self.enterabs(time, action) + + async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. + """ + task = Task(time, action) + async with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + async def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + async with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + if not task.action: + break + try: + await task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + try: + await asyncio.wait_for(self._event.wait(), delay) + except asyncio.TimeoutError: + pass diff --git a/psycopg_pool/psycopg_pool/version.py b/psycopg_pool/psycopg_pool/version.py new file mode 100644 index 0000000..fc99bbd --- /dev/null +++ b/psycopg_pool/psycopg_pool/version.py @@ -0,0 +1,13 @@ +""" +psycopg pool version file. +""" + +# Copyright (C) 2021 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ + +# STOP AND READ! if you change: +__version__ = "3.1.5" +# also change: +# - `docs/news_pool.rst` to declare this version current or unreleased diff --git a/psycopg_pool/setup.cfg b/psycopg_pool/setup.cfg new file mode 100644 index 0000000..1a3274e --- /dev/null +++ b/psycopg_pool/setup.cfg @@ -0,0 +1,45 @@ +[metadata] +name = psycopg-pool +description = Connection Pool for Psycopg +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg-pool/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +packages = find: +zip_safe = False +install_requires = + typing-extensions >= 3.10 + +[options.package_data] +psycopg_pool = py.typed diff --git a/psycopg_pool/setup.py b/psycopg_pool/setup.py new file mode 100644 index 0000000..771847d --- /dev/null +++ b/psycopg_pool/setup.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - Connection Pool +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +from setuptools import setup + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +with open("psycopg_pool/version.py") as f: + data = f.read() + m = re.search(r"""(?m)^__version__\s*=\s*['"]([^'"]+)['"]""", data) + if not m: + raise Exception(f"cannot find version in {f.name}") + version = m.group(1) + + +setup(version=version) |