Diffstat (limited to 'psycopg')
70 files changed, 19768 insertions, 0 deletions
diff --git a/psycopg/.flake8 b/psycopg/.flake8 new file mode 100644 index 0000000..67fb024 --- /dev/null +++ b/psycopg/.flake8 @@ -0,0 +1,6 @@ +[flake8] +max-line-length = 88 +ignore = W503, E203 +per-file-ignores = + # Autogenerated section + psycopg/errors.py: E125, E128, E302 diff --git a/psycopg/LICENSE.txt b/psycopg/LICENSE.txt new file mode 100644 index 0000000..0a04128 --- /dev/null +++ b/psycopg/LICENSE.txt @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. 
You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. 
Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/psycopg/README.rst b/psycopg/README.rst new file mode 100644 index 0000000..45eeac3 --- /dev/null +++ b/psycopg/README.rst @@ -0,0 +1,31 @@ +Psycopg 3: PostgreSQL database adapter for Python +================================================= + +Psycopg 3 is a modern implementation of a PostgreSQL adapter for Python. + +This distribution contains the pure Python package ``psycopg``. + + +Installation +------------ + +In short, run the following:: + + pip install --upgrade pip # to upgrade pip + pip install "psycopg[binary,pool]" # to install package and dependencies + +If something goes wrong, and for more information about installation, please +check out the `Installation documentation`__. + +.. __: https://www.psycopg.org/psycopg3/docs/basic/install.html# + + +Hacking +------- + +For development information check out `the project readme`__. + +.. __: https://github.com/psycopg/psycopg#readme + + +Copyright (C) 2020 The Psycopg Team diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py new file mode 100644 index 0000000..baadf30 --- /dev/null +++ b/psycopg/psycopg/__init__.py @@ -0,0 +1,110 @@ +""" +psycopg -- PostgreSQL database adapter for Python +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from . import pq # noqa: F401 import early to stabilize side effects +from . import types +from . import postgres +from ._tpc import Xid +from .copy import Copy, AsyncCopy +from ._enums import IsolationLevel +from .cursor import Cursor +from .errors import Warning, Error, InterfaceError, DatabaseError +from .errors import DataError, OperationalError, IntegrityError +from .errors import InternalError, ProgrammingError, NotSupportedError +from ._column import Column +from .conninfo import ConnectionInfo +from ._pipeline import Pipeline, AsyncPipeline +from .connection import BaseConnection, Connection, Notify +from .transaction import Rollback, Transaction, AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor, ServerCursor +from .client_cursor import AsyncClientCursor, ClientCursor +from .connection_async import AsyncConnection + +from . 
import dbapi20 +from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING +from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks +from .dbapi20 import Timestamp, TimestampFromTicks + +from .version import __version__ as __version__ # noqa: F401 + +# Set the logger to a quiet default, can be enabled if needed +logger = logging.getLogger("psycopg") +if logger.level == logging.NOTSET: + logger.setLevel(logging.WARNING) + +# DBAPI compliance +connect = Connection.connect +apilevel = "2.0" +threadsafety = 2 +paramstyle = "pyformat" + +# register default adapters for PostgreSQL +adapters = postgres.adapters # exposed by the package +postgres.register_default_adapters(adapters) + +# After the default ones, because these can deal with the bytea oid better +dbapi20.register_dbapi20_adapters(adapters) + +# Must come after all the types have been registered +types.array.register_all_arrays(adapters) + +# Note: defining the exported methods helps both Sphynx in documenting that +# this is the canonical place to obtain them and should be used by MyPy too, +# so that function signatures are consistent with the documentation. +__all__ = [ + "AsyncClientCursor", + "AsyncConnection", + "AsyncCopy", + "AsyncCursor", + "AsyncPipeline", + "AsyncServerCursor", + "AsyncTransaction", + "BaseConnection", + "ClientCursor", + "Column", + "Connection", + "ConnectionInfo", + "Copy", + "Cursor", + "IsolationLevel", + "Notify", + "Pipeline", + "Rollback", + "ServerCursor", + "Transaction", + "Xid", + # DBAPI exports + "connect", + "apilevel", + "threadsafety", + "paramstyle", + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + # DBAPI type constructors and singletons + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "DATETIME", + "NUMBER", + "ROWID", + "STRING", +] diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py new file mode 100644 index 0000000..a3a6ef8 --- /dev/null +++ b/psycopg/psycopg/_adapters_map.py @@ -0,0 +1,289 @@ +""" +Mapping from types/oids to Dumpers/Loaders +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import cast, TYPE_CHECKING + +from . import pq +from . import errors as e +from .abc import Dumper, Loader +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg +from ._typeinfo import TypesRegistry + +if TYPE_CHECKING: + from .connection import BaseConnection + +RV = TypeVar("RV") + + +class AdaptersMap: + r""" + Establish how types should be converted between Python and PostgreSQL in + an `~psycopg.abc.AdaptContext`. + + `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to + define how Python types are converted to PostgreSQL, and maps OIDs to + `~psycopg.adapt.Loader` classes to establish how query results are + converted to Python. + + Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how + types are converted in that context, exposed as the + `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows + to customise adaptation in a context without changing separated contexts. 
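A minimal usage sketch of the DBAPI entry points exported by psycopg/__init__.py above (illustrative only, not part of this changeset; the connection string is a placeholder):

import psycopg

with psycopg.connect("dbname=test user=postgres") as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT %s::int + %s::int", (1, 2))
        print(cur.fetchone())  # (3,)

Here `connect()` is `Connection.connect` and `paramstyle` is "pyformat", so query parameters are passed with `%s`-style placeholders as shown.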
+ + When a context is created from another context (for instance when a + `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's + `!adapters` are used as template for the child's `!adapters`, so that every + cursor created from the same connection use the connection's types + configuration, but separate connections have independent mappings. + + Once created, `!AdaptersMap` are independent. This means that objects + already created are not affected if a wider scope (e.g. the global one) is + changed. + + The connections adapters are initialised using a global `!AdptersMap` + template, exposed as `psycopg.adapters`: changing such mapping allows to + customise the type mapping for every connections created afterwards. + + The object can start empty or copy from another object of the same class. + Copies are copy-on-write: if the maps are updated make a copy. This way + extending e.g. global map by a connection or a connection map from a cursor + is cheap: a copy is only made on customisation. + """ + + __module__ = "psycopg.adapt" + + types: TypesRegistry + + _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]] + _dumpers_by_oid: List[Dict[int, Type[Dumper]]] + _loaders: List[Dict[int, Type[Loader]]] + + # Record if a dumper or loader has an optimised version. + _optimised: Dict[type, type] = {} + + def __init__( + self, + template: Optional["AdaptersMap"] = None, + types: Optional[TypesRegistry] = None, + ): + if template: + self._dumpers = template._dumpers.copy() + self._own_dumpers = _dumpers_shared.copy() + template._own_dumpers = _dumpers_shared.copy() + + self._dumpers_by_oid = template._dumpers_by_oid[:] + self._own_dumpers_by_oid = [False, False] + template._own_dumpers_by_oid = [False, False] + + self._loaders = template._loaders[:] + self._own_loaders = [False, False] + template._own_loaders = [False, False] + + self.types = TypesRegistry(template.types) + + else: + self._dumpers = {fmt: {} for fmt in PyFormat} + self._own_dumpers = _dumpers_owned.copy() + + self._dumpers_by_oid = [{}, {}] + self._own_dumpers_by_oid = [True, True] + + self._loaders = [{}, {}] + self._own_loaders = [True, True] + + self.types = types or TypesRegistry() + + # implement the AdaptContext protocol too + @property + def adapters(self) -> "AdaptersMap": + return self + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return None + + def register_dumper( + self, cls: Union[type, str, None], dumper: Type[Dumper] + ) -> None: + """ + Configure the context to use `!dumper` to convert objects of type `!cls`. + + If two dumpers with different `~Dumper.format` are registered for the + same type, the last one registered will be chosen when the query + doesn't specify a format (i.e. when the value is used with a ``%s`` + "`~PyFormat.AUTO`" placeholder). + + :param cls: The type to manage. + :param dumper: The dumper to register for `!cls`. + + If `!cls` is specified as string it will be lazy-loaded, so that it + will be possible to register it without importing it before. In this + case it should be the fully qualified name of the object (e.g. + ``"uuid.UUID"``). + + If `!cls` is None, only use the dumper when looking up using + `get_dumper_by_oid()`, which happens when we know the Postgres type to + adapt to, but not the Python type that will be adapted (e.g. in COPY + after using `~psycopg.Copy.set_types()`). 
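A sketch of `register_dumper()` in use (illustrative; `Point` is a hypothetical user class and its text rendering is an assumption):

import psycopg
from psycopg.adapt import Dumper

class Point:                      # hypothetical user-defined type
    def __init__(self, x, y):
        self.x, self.y = x, y

class PointDumper(Dumper):
    def dump(self, obj):
        # render the object as a PostgreSQL point literal, e.g. b"(1.5,2.5)"
        return f"({obj.x},{obj.y})".encode()

psycopg.adapters.register_dumper(Point, PointDumper)   # on the global map
# or conn.adapters.register_dumper(...) to affect a single connection only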
+ + """ + if not (cls is None or isinstance(cls, (str, type))): + raise TypeError( + f"dumpers should be registered on classes, got {cls} instead" + ) + + if _psycopg: + dumper = self._get_optimised(dumper) + + # Register the dumper both as its format and as auto + # so that the last dumper registered is used in auto (%s) format + if cls: + for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO): + if not self._own_dumpers[fmt]: + self._dumpers[fmt] = self._dumpers[fmt].copy() + self._own_dumpers[fmt] = True + + self._dumpers[fmt][cls] = dumper + + # Register the dumper by oid, if the oid of the dumper is fixed + if dumper.oid: + if not self._own_dumpers_by_oid[dumper.format]: + self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[ + dumper.format + ].copy() + self._own_dumpers_by_oid[dumper.format] = True + + self._dumpers_by_oid[dumper.format][dumper.oid] = dumper + + def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None: + """ + Configure the context to use `!loader` to convert data of oid `!oid`. + + :param oid: The PostgreSQL OID or type name to manage. + :param loader: The loar to register for `!oid`. + + If `oid` is specified as string, it refers to a type name, which is + looked up in the `types` registry. ` + + """ + if isinstance(oid, str): + oid = self.types[oid].oid + if not isinstance(oid, int): + raise TypeError(f"loaders should be registered on oid, got {oid} instead") + + if _psycopg: + loader = self._get_optimised(loader) + + fmt = loader.format + if not self._own_loaders[fmt]: + self._loaders[fmt] = self._loaders[fmt].copy() + self._own_loaders[fmt] = True + + self._loaders[fmt][oid] = loader + + def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]: + """ + Return the dumper class for the given type and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param cls: The class to adapt. + :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`, + use the last one of the dumpers registered on `!cls`. + """ + try: + dmap = self._dumpers[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + # Look for the right class, including looking at superclasses + for scls in cls.__mro__: + if scls in dmap: + return dmap[scls] + + # If the adapter is not found, look for its name as a string + fqn = scls.__module__ + "." + scls.__qualname__ + if fqn in dmap: + # Replace the class name with the class itself + d = dmap[scls] = dmap.pop(fqn) + return d + + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__!r} using placeholder '%{format}'" + f" (format: {PyFormat(format).name})" + ) + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]: + """ + Return the dumper class for the given oid and format. + + Raise `~psycopg.ProgrammingError` if a class is not available. + + :param oid: The oid of the type to dump to. + :param format: The format to dump to. 
+ """ + try: + dmap = self._dumpers_by_oid[format] + except KeyError: + raise ValueError(f"bad dumper format: {format}") + + try: + return dmap[oid] + except KeyError: + info = self.types.get(oid) + if info: + msg = ( + f"cannot find a dumper for type {info.name} (oid {oid})" + f" format {pq.Format(format).name}" + ) + else: + msg = ( + f"cannot find a dumper for unknown type with oid {oid}" + f" format {pq.Format(format).name}" + ) + raise e.ProgrammingError(msg) + + def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]: + """ + Return the loader class for the given oid and format. + + Return `!None` if not found. + + :param oid: The oid of the type to load. + :param format: The format to load from. + """ + return self._loaders[format].get(oid) + + @classmethod + def _get_optimised(self, cls: Type[RV]) -> Type[RV]: + """Return the optimised version of a Dumper or Loader class. + + Return the input class itself if there is no optimised version. + """ + try: + return self._optimised[cls] + except KeyError: + pass + + # Check if the class comes from psycopg.types and there is a class + # with the same name in psycopg_c._psycopg. + from psycopg import types + + if cls.__module__.startswith(types.__name__): + new = cast(Type[RV], getattr(_psycopg, cls.__name__, None)) + if new: + self._optimised[cls] = new + return new + + self._optimised[cls] = cls + return cls + + +# Micro-optimization: copying these objects is faster than creating new dicts +_dumpers_owned = dict.fromkeys(PyFormat, True) +_dumpers_shared = dict.fromkeys(PyFormat, False) diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py new file mode 100644 index 0000000..288ef1b --- /dev/null +++ b/psycopg/psycopg/_cmodule.py @@ -0,0 +1,24 @@ +""" +Simplify access to the _psycopg module +""" + +# Copyright (C) 2021 The Psycopg Team + +from typing import Optional + +from . import pq + +__version__: Optional[str] = None + +# Note: "c" must the first attempt so that mypy associates the variable the +# right module interface. It will not result Optional, but hey. 
+if pq.__impl__ == "c": + from psycopg_c import _psycopg as _psycopg + from psycopg_c import __version__ as __version__ # noqa: F401 +elif pq.__impl__ == "binary": + from psycopg_binary import _psycopg as _psycopg # type: ignore + from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401 +elif pq.__impl__ == "python": + _psycopg = None # type: ignore +else: + raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}") diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py new file mode 100644 index 0000000..9e4e735 --- /dev/null +++ b/psycopg/psycopg/_column.py @@ -0,0 +1,143 @@ +""" +The Column object in Cursor.description +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING +from operator import attrgetter + +if TYPE_CHECKING: + from .cursor import BaseCursor + + +class ColumnData(NamedTuple): + ftype: int + fmod: int + fsize: int + + +class Column(Sequence[Any]): + + __module__ = "psycopg" + + def __init__(self, cursor: "BaseCursor[Any, Any]", index: int): + res = cursor.pgresult + assert res + + fname = res.fname(index) + if fname: + self._name = fname.decode(cursor._encoding) + else: + # COPY_OUT results have columns but no name + self._name = f"column_{index + 1}" + + self._data = ColumnData( + ftype=res.ftype(index), + fmod=res.fmod(index), + fsize=res.fsize(index), + ) + self._type = cursor.adapters.types.get(self._data.ftype) + + _attrs = tuple( + attrgetter(attr) + for attr in """ + name type_code display_size internal_size precision scale null_ok + """.split() + ) + + def __repr__(self) -> str: + return ( + f"<Column {self.name!r}," + f" type: {self._type_display()} (oid: {self.type_code})>" + ) + + def __len__(self) -> int: + return 7 + + def _type_display(self) -> str: + parts = [] + parts.append(self._type.name if self._type else str(self.type_code)) + + mod1 = self.precision + if mod1 is None: + mod1 = self.display_size + if mod1: + parts.append(f"({mod1}") + if self.scale: + parts.append(f", {self.scale}") + parts.append(")") + + if self._type and self.type_code == self._type.array_oid: + parts.append("[]") + + return "".join(parts) + + def __getitem__(self, index: Any) -> Any: + if isinstance(index, slice): + return tuple(getter(self) for getter in self._attrs[index]) + else: + return self._attrs[index](self) + + @property + def name(self) -> str: + """The name of the column.""" + return self._name + + @property + def type_code(self) -> int: + """The numeric OID of the column.""" + return self._data.ftype + + @property + def display_size(self) -> Optional[int]: + """The field size, for :sql:`varchar(n)`, None otherwise.""" + if not self._type: + return None + + if self._type.name in ("varchar", "char"): + fmod = self._data.fmod + if fmod >= 0: + return fmod - 4 + + return None + + @property + def internal_size(self) -> Optional[int]: + """The internal field size for fixed-size types, None otherwise.""" + fsize = self._data.fsize + return fsize if fsize >= 0 else None + + @property + def precision(self) -> Optional[int]: + """The number of digits for fixed precision types.""" + if not self._type: + return None + + dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval") + if self._type.name == "numeric": + fmod = self._data.fmod + if fmod >= 0: + return fmod >> 16 + + elif self._type.name in dttypes: + fmod = self._data.fmod + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def scale(self) -> Optional[int]: + """The 
number of digits after the decimal point if available.""" + if self._type and self._type.name == "numeric": + fmod = self._data.fmod - 4 + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def null_ok(self) -> Optional[bool]: + """Always `!None`""" + return None diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py new file mode 100644 index 0000000..7dbae79 --- /dev/null +++ b/psycopg/psycopg/_compat.py @@ -0,0 +1,72 @@ +""" +compatibility functions for different Python versions +""" + +# Copyright (C) 2021 The Psycopg Team + +import sys +import asyncio +from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar + +# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it. +# For this raisin it must be imported directly from typing_extension where used. +# See https://github.com/microsoft/pyright/issues/4197 +from typing_extensions import TypeAlias + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +T = TypeVar("T") +FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] + +if sys.version_info >= (3, 8): + create_task = asyncio.create_task + from math import prod + +else: + + def create_task( + coro: FutureT[T], name: Optional[str] = None + ) -> "asyncio.Future[T]": + return asyncio.create_task(coro) + + from functools import reduce + + def prod(seq: Sequence[int]) -> int: + return reduce(int.__mul__, seq, 1) + + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo + from functools import cache + from collections import Counter, deque as Deque +else: + from typing import Counter, Deque + from functools import lru_cache + from backports.zoneinfo import ZoneInfo + + cache = lru_cache(maxsize=None) + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +if sys.version_info >= (3, 11): + from typing import LiteralString +else: + from typing_extensions import LiteralString + +__all__ = [ + "Counter", + "Deque", + "LiteralString", + "Protocol", + "TypeGuard", + "ZoneInfo", + "cache", + "create_task", + "prod", +] diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py new file mode 100644 index 0000000..1e146ba --- /dev/null +++ b/psycopg/psycopg/_dns.py @@ -0,0 +1,223 @@ +# type: ignore # dnspython is currently optional and mypy fails if missing +""" +DNS query support +""" + +# Copyright (C) 2021 The Psycopg Team + +import os +import re +import warnings +from random import randint +from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence +from typing import TYPE_CHECKING +from collections import defaultdict + +try: + from dns.resolver import Resolver, Cache + from dns.asyncresolver import Resolver as AsyncResolver + from dns.exception import DNSException +except ImportError: + raise ImportError( + "the module psycopg._dns requires the package 'dnspython' installed" + ) + +from . import errors as e +from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_ + +if TYPE_CHECKING: + from dns.rdtypes.IN.SRV import SRV + +resolver = Resolver() +resolver.cache = Cache() + +async_resolver = AsyncResolver() +async_resolver.cache = Cache() + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + .. 
deprecated:: 3.1 + The use of this function is not necessary anymore, because + `psycopg.AsyncConnection.connect()` performs non-blocking name + resolution automatically. + """ + warnings.warn( + "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore", + DeprecationWarning, + ) + return await resolve_hostaddr_async_(params) + + +def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]: + """Apply SRV DNS lookup as defined in :RFC:`2782`.""" + return Rfc2782Resolver().resolve(params) + + +async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]: + """Async equivalent of `resolve_srv()`.""" + return await Rfc2782Resolver().resolve_async(params) + + +class HostPort(NamedTuple): + host: str + port: str + totry: bool = False + target: Optional[str] = None + + +class Rfc2782Resolver: + """Implement SRV RR Resolution as per RFC 2782 + + The class is organised to minimise code duplication between the sync and + the async paths. + """ + + re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)") + + def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(self._resolve_srv(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Update the parameters host and port after SRV lookup.""" + attempts = self._get_attempts(params) + if not attempts: + return params + + hps = [] + for hp in attempts: + if hp.totry: + hps.extend(await self._resolve_srv_async(hp)) + else: + hps.append(hp) + + return self._return_params(params, hps) + + def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]: + """ + Return the list of host, and for each host if SRV lookup must be tried. + + Return an empty list if no lookup is requested. + """ + # If hostaddr is defined don't do any resolution. + if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")): + return [] + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") + + if len(ports_in) == 1: + # If only one port is specified, it applies to all the hosts. + ports_in *= len(hosts_in) + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. 
+ raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + + out = [] + srv_found = False + for host, port in zip(hosts_in, ports_in): + m = self.re_srv_rr.match(host) + if m or port.lower() == "srv": + srv_found = True + target = m.group("target") if m else None + hp = HostPort(host=host, port=port, totry=True, target=target) + else: + hp = HostPort(host=host, port=port) + out.append(hp) + + return out if srv_found else [] + + def _resolve_srv(self, hp: HostPort) -> List[HostPort]: + try: + ans = resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]: + try: + ans = await async_resolver.resolve(hp.host, "SRV") + except DNSException: + ans = () + return self._get_solved_entries(hp, ans) + + def _get_solved_entries( + self, hp: HostPort, entries: "Sequence[SRV]" + ) -> List[HostPort]: + if not entries: + # No SRV entry found. Delegate the libpq a QNAME=target lookup + if hp.target and hp.port.lower() != "srv": + return [HostPort(host=hp.target, port=hp.port)] + else: + return [] + + # If there is precisely one SRV RR, and its Target is "." (the root + # domain), abort. + if len(entries) == 1 and str(entries[0].target) == ".": + return [] + + return [ + HostPort(host=str(entry.target).rstrip("."), port=str(entry.port)) + for entry in self.sort_rfc2782(entries) + ] + + def _return_params( + self, params: Dict[str, Any], hps: List[HostPort] + ) -> Dict[str, Any]: + if not hps: + # Nothing found, we ended up with an empty list + raise e.OperationalError("no host found after SRV RR lookup") + + out = params.copy() + out["host"] = ",".join(hp.host for hp in hps) + out["port"] = ",".join(str(hp.port) for hp in hps) + return out + + def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]": + """ + Implement the priority/weight ordering defined in RFC 2782. + """ + # Divide the entries by priority: + priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list) + out: "List[SRV]" = [] + for entry in ans: + priorities[entry.priority].append(entry) + + for pri, entries in sorted(priorities.items()): + if len(entries) == 1: + out.append(entries[0]) + continue + + entries.sort(key=lambda ent: ent.weight) + total_weight = sum(ent.weight for ent in entries) + while entries: + r = randint(0, total_weight) + csum = 0 + for i, ent in enumerate(entries): + csum += ent.weight + if csum >= r: + break + out.append(ent) + total_weight -= ent.weight + del entries[i] + + return out diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py new file mode 100644 index 0000000..c584b26 --- /dev/null +++ b/psycopg/psycopg/_encodings.py @@ -0,0 +1,170 @@ +""" +Mappings between PostgreSQL and Python encodings. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import string +import codecs +from typing import Any, Dict, Optional, TYPE_CHECKING + +from .pq._enums import ConnStatus +from .errors import NotSupportedError +from ._compat import cache + +if TYPE_CHECKING: + from .pq.abc import PGconn + from .connection import BaseConnection + +OK = ConnStatus.OK + + +_py_codecs = { + "BIG5": "big5", + "EUC_CN": "gb2312", + "EUC_JIS_2004": "euc_jis_2004", + "EUC_JP": "euc_jp", + "EUC_KR": "euc_kr", + # "EUC_TW": not available in Python + "GB18030": "gb18030", + "GBK": "gbk", + "ISO_8859_5": "iso8859-5", + "ISO_8859_6": "iso8859-6", + "ISO_8859_7": "iso8859-7", + "ISO_8859_8": "iso8859-8", + "JOHAB": "johab", + "KOI8R": "koi8-r", + "KOI8U": "koi8-u", + "LATIN1": "iso8859-1", + "LATIN10": "iso8859-16", + "LATIN2": "iso8859-2", + "LATIN3": "iso8859-3", + "LATIN4": "iso8859-4", + "LATIN5": "iso8859-9", + "LATIN6": "iso8859-10", + "LATIN7": "iso8859-13", + "LATIN8": "iso8859-14", + "LATIN9": "iso8859-15", + # "MULE_INTERNAL": not available in Python + "SHIFT_JIS_2004": "shift_jis_2004", + "SJIS": "shift_jis", + # this actually means no encoding, see PostgreSQL docs + # it is special-cased by the text loader. + "SQL_ASCII": "ascii", + "UHC": "cp949", + "UTF8": "utf-8", + "WIN1250": "cp1250", + "WIN1251": "cp1251", + "WIN1252": "cp1252", + "WIN1253": "cp1253", + "WIN1254": "cp1254", + "WIN1255": "cp1255", + "WIN1256": "cp1256", + "WIN1257": "cp1257", + "WIN1258": "cp1258", + "WIN866": "cp866", + "WIN874": "cp874", +} + +py_codecs: Dict[bytes, str] = {} +py_codecs.update((k.encode(), v) for k, v in _py_codecs.items()) + +# Add an alias without underscore, for lenient lookups +py_codecs.update( + (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k +) + +pg_codecs = {v: k.encode() for k, v in _py_codecs.items()} + + +def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str: + """ + Return the Python encoding name of a psycopg connection. + + Default to utf8 if the connection has no encoding info. + """ + if not conn or conn.closed: + return "utf-8" + + pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def pgconn_encoding(pgconn: "PGconn") -> str: + """ + Return the Python encoding name of a libpq connection. + + Default to utf8 if the connection has no encoding info. + """ + if pgconn.status != OK: + return "utf-8" + + pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + +def conninfo_encoding(conninfo: str) -> str: + """ + Return the Python encoding name passed in a conninfo string. Default to utf8. + + Because the input is likely to come from the user and not normalised by the + server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars). + """ + from .conninfo import conninfo_to_dict + + params = conninfo_to_dict(conninfo) + pgenc = params.get("client_encoding") + if pgenc: + try: + return pg2pyenc(pgenc.encode()) + except NotSupportedError: + pass + + return "utf-8" + + +@cache +def py2pgenc(name: str) -> bytes: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise LookupError if the Python encoding is unknown. + """ + return pg_codecs[codecs.lookup(name).name] + + +@cache +def pg2pyenc(name: bytes) -> str: + """Convert a Python encoding name to PostgreSQL encoding name. + + Raise NotSupportedError if the PostgreSQL encoding is not supported by + Python. 
+ """ + try: + return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()] + except KeyError: + sname = name.decode("utf8", "replace") + raise NotSupportedError(f"codec not available in Python: {sname!r}") + + +def _as_python_identifier(s: str, prefix: str = "f") -> str: + """ + Reduce a string to a valid Python identifier. + + Replace all non-valid chars with '_' and prefix the value with `!prefix` if + the first letter is an '_'. + """ + if not s.isidentifier(): + if s[0] in "1234567890": + s = prefix + s + if not s.isidentifier(): + s = _re_clean.sub("_", s) + # namedtuple fields cannot start with underscore. So... + if s[0] == "_": + s = prefix + s + return s + + +_re_clean = re.compile( + f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]" +) diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py new file mode 100644 index 0000000..a7cb78d --- /dev/null +++ b/psycopg/psycopg/_enums.py @@ -0,0 +1,79 @@ +""" +Enum values for psycopg + +These values are defined by us and are not necessarily dependent on +libpq-defined enums. +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import Enum, IntEnum +from selectors import EVENT_READ, EVENT_WRITE + +from . import pq + + +class Wait(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class Ready(IntEnum): + R = EVENT_READ + W = EVENT_WRITE + RW = EVENT_READ | EVENT_WRITE + + +class PyFormat(str, Enum): + """ + Enum representing the format wanted for a query argument. + + The value `AUTO` allows psycopg to choose the best format for a certain + parameter. + """ + + __module__ = "psycopg.adapt" + + AUTO = "s" + """Automatically chosen (``%s`` placeholder).""" + TEXT = "t" + """Text parameter (``%t`` placeholder).""" + BINARY = "b" + """Binary parameter (``%b`` placeholder).""" + + @classmethod + def from_pq(cls, fmt: pq.Format) -> "PyFormat": + return _pg2py[fmt] + + @classmethod + def as_pq(cls, fmt: "PyFormat") -> pq.Format: + return _py2pg[fmt] + + +class IsolationLevel(IntEnum): + """ + Enum representing the isolation level for a transaction. + """ + + __module__ = "psycopg" + + READ_UNCOMMITTED = 1 + """:sql:`READ UNCOMMITTED` isolation level.""" + READ_COMMITTED = 2 + """:sql:`READ COMMITTED` isolation level.""" + REPEATABLE_READ = 3 + """:sql:`REPEATABLE READ` isolation level.""" + SERIALIZABLE = 4 + """:sql:`SERIALIZABLE` isolation level.""" + + +_py2pg = { + PyFormat.TEXT: pq.Format.TEXT, + PyFormat.BINARY: pq.Format.BINARY, +} + +_pg2py = { + pq.Format.TEXT: PyFormat.TEXT, + pq.Format.BINARY: PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py new file mode 100644 index 0000000..c818d86 --- /dev/null +++ b/psycopg/psycopg/_pipeline.py @@ -0,0 +1,288 @@ +""" +commands pipeline management +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +from types import TracebackType +from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from . 
import errors as e +from .abc import PipelineCommand, PQGen +from ._compat import Deque +from ._encodings import pgconn_encoding +from ._preparing import Key, Prepare +from .generators import pipeline_communicate, fetch_many, send + +if TYPE_CHECKING: + from .pq.abc import PGresult + from .cursor import BaseCursor + from .connection import BaseConnection, Connection + from .connection_async import AsyncConnection + + +PendingResult: TypeAlias = Union[ + None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]] +] + +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED +BAD = pq.ConnStatus.BAD + +ACTIVE = pq.TransactionStatus.ACTIVE + +logger = logging.getLogger("psycopg") + + +class BasePipeline: + + command_queue: Deque[PipelineCommand] + result_queue: Deque[PendingResult] + _is_supported: Optional[bool] = None + + def __init__(self, conn: "BaseConnection[Any]") -> None: + self._conn = conn + self.pgconn = conn.pgconn + self.command_queue = Deque[PipelineCommand]() + self.result_queue = Deque[PendingResult]() + self.level = 0 + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._conn.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def status(self) -> pq.PipelineStatus: + return pq.PipelineStatus(self.pgconn.pipeline_status) + + @classmethod + def is_supported(cls) -> bool: + """Return `!True` if the psycopg libpq wrapper supports pipeline mode.""" + if BasePipeline._is_supported is None: + BasePipeline._is_supported = not cls._not_supported_reason() + return BasePipeline._is_supported + + @classmethod + def _not_supported_reason(cls) -> str: + """Return the reason why the pipeline mode is not supported. + + Return an empty string if pipeline mode is supported. + """ + # Support only depends on the libpq functions available in the pq + # wrapper, not on the database version. + if pq.version() < 140000: + return ( + f"libpq too old {pq.version()};" + " v14 or greater required for pipeline mode" + ) + + if pq.__build_version__ < 140000: + return ( + f"libpq too old: module built for {pq.__build_version__};" + " v14 or greater required for pipeline mode" + ) + + return "" + + def _enter_gen(self) -> PQGen[None]: + if not self.is_supported(): + raise e.NotSupportedError( + f"pipeline mode not supported: {self._not_supported_reason()}" + ) + if self.level == 0: + self.pgconn.enter_pipeline_mode() + elif self.command_queue or self.pgconn.transaction_status == ACTIVE: + # Nested pipeline case. + # Transaction might be ACTIVE when the pipeline uses an "implicit + # transaction", typically in autocommit mode. But when entering a + # Psycopg transaction(), we expect the IDLE state. By sync()-ing, + # we make sure all previous commands are completed and the + # transaction gets back to IDLE. + yield from self._sync_gen() + self.level += 1 + + def _exit(self, exc: Optional[BaseException]) -> None: + self.level -= 1 + if self.level == 0 and self.pgconn.status != BAD: + try: + self.pgconn.exit_pipeline_mode() + except e.OperationalError as exc2: + # Notice that this error might be pretty irrecoverable. 
It + # happens on COPY, for instance: even if sync succeeds, exiting + # fails with "cannot exit pipeline mode with uncollected results" + if exc: + logger.warning("error ignored exiting %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + + def _sync_gen(self) -> PQGen[None]: + self._enqueue_sync() + yield from self._communicate_gen() + yield from self._fetch_gen(flush=False) + + def _exit_gen(self) -> PQGen[None]: + """ + Exit current pipeline by sending a Sync and fetch back all remaining results. + """ + try: + self._enqueue_sync() + yield from self._communicate_gen() + finally: + # No need to force flush since we emitted a sync just before. + yield from self._fetch_gen(flush=False) + + def _communicate_gen(self) -> PQGen[None]: + """Communicate with pipeline to send commands and possibly fetch + results, which are then processed. + """ + fetched = yield from pipeline_communicate(self.pgconn, self.command_queue) + to_process = [(self.result_queue.popleft(), results) for results in fetched] + for queued, results in to_process: + self._process_results(queued, results) + + def _fetch_gen(self, *, flush: bool) -> PQGen[None]: + """Fetch available results from the connection and process them with + pipeline queued items. + + If 'flush' is True, a PQsendFlushRequest() is issued in order to make + sure results can be fetched. Otherwise, the caller may emit a + PQpipelineSync() call to ensure the output buffer gets flushed before + fetching. + """ + if not self.result_queue: + return + + if flush: + self.pgconn.send_flush_request() + yield from send(self.pgconn) + + to_process = [] + while self.result_queue: + results = yield from fetch_many(self.pgconn) + if not results: + # No more results to fetch, but there may still be pending + # commands. + break + queued = self.result_queue.popleft() + to_process.append((queued, results)) + + for queued, results in to_process: + self._process_results(queued, results) + + def _process_results( + self, queued: PendingResult, results: List["PGresult"] + ) -> None: + """Process a results set fetched from the current pipeline. + + This matches 'results' with its respective element in the pipeline + queue. For commands (None value in the pipeline queue), results are + checked directly. For prepare statement creation requests, update the + cache. Otherwise, results are attached to their respective cursor. + """ + if queued is None: + (result,) = results + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + elif result.status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + cursor, prepinfo = queued + cursor._set_results_from_pipeline(results) + if prepinfo: + key, prep, name = prepinfo + # Update the prepare state of the query. + cursor._conn._prepared.validate(key, prep, name, results) + + def _enqueue_sync(self) -> None: + """Enqueue a PQpipelineSync() command.""" + self.command_queue.append(self.pgconn.pipeline_sync) + self.result_queue.append(None) + + +class Pipeline(BasePipeline): + """Handler for connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "Connection[Any]" + _Self = TypeVar("_Self", bound="Pipeline") + + def __init__(self, conn: "Connection[Any]") -> None: + super().__init__(conn) + + def sync(self) -> None: + """Sync the pipeline, send any pending command and receive and process + all available results. 
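A usage sketch for the public side of this machinery, `Connection.pipeline()` (illustrative, not part of this changeset; requires libpq v14 or later, and the conninfo and table are placeholders):

import psycopg

with psycopg.connect("dbname=test") as conn:
    with conn.pipeline() as p:
        with conn.cursor() as cur:
            cur.execute("INSERT INTO t (x) VALUES (%s)", (1,))
            cur.execute("INSERT INTO t (x) VALUES (%s)", (2,))
            p.sync()   # emit a sync point and process the available results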
+ """ + try: + with self._conn.lock: + self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + with self._conn.lock: + self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) + + +class AsyncPipeline(BasePipeline): + """Handler for async connection in pipeline mode.""" + + __module__ = "psycopg" + _conn: "AsyncConnection[Any]" + _Self = TypeVar("_Self", bound="AsyncPipeline") + + def __init__(self, conn: "AsyncConnection[Any]") -> None: + super().__init__(conn) + + async def sync(self) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._sync_gen()) + except e.Error as ex: + raise ex.with_traceback(None) + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + try: + async with self._conn.lock: + await self._conn.wait(self._exit_gen()) + except Exception as exc2: + # Don't clobber an exception raised in the block with this one + if exc_val: + logger.warning("error ignored terminating %r: %s", self, exc2) + else: + raise exc2.with_traceback(None) + finally: + self._exit(exc_val) diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py new file mode 100644 index 0000000..f60c0cb --- /dev/null +++ b/psycopg/psycopg/_preparing.py @@ -0,0 +1,198 @@ +""" +Support for prepared statements +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, auto +from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING +from collections import OrderedDict +from typing_extensions import TypeAlias + +from . import pq +from ._compat import Deque +from ._queries import PostgresQuery + +if TYPE_CHECKING: + from .pq.abc import PGresult + +Key: TypeAlias = Tuple[bytes, Tuple[int, ...]] + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + + +class Prepare(IntEnum): + NO = auto() + YES = auto() + SHOULD = auto() + + +class PrepareManager: + # Number of times a query is executed before it is prepared. + prepare_threshold: Optional[int] = 5 + + # Maximum number of prepared statements on the connection. + prepared_max: int = 100 + + def __init__(self) -> None: + # Map (query, types) to the number of times the query was seen. + self._counts: OrderedDict[Key, int] = OrderedDict() + + # Map (query, types) to the name of the statement if prepared. + self._names: OrderedDict[Key, bytes] = OrderedDict() + + # Counter to generate prepared statements names + self._prepared_idx = 0 + + self._maint_commands = Deque[bytes]() + + @staticmethod + def key(query: PostgresQuery) -> Key: + return (query.query, query.types) + + def get( + self, query: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + """ + Check if a query is prepared, tell back whether to prepare it. 
+ """ + if prepare is False or self.prepare_threshold is None: + # The user doesn't want this query to be prepared + return Prepare.NO, b"" + + key = self.key(query) + name = self._names.get(key) + if name: + # The query was already prepared in this session + return Prepare.YES, name + + count = self._counts.get(key, 0) + if count >= self.prepare_threshold or prepare: + # The query has been executed enough times and needs to be prepared + name = f"_pg3_{self._prepared_idx}".encode() + self._prepared_idx += 1 + return Prepare.SHOULD, name + else: + # The query is not to be prepared yet + return Prepare.NO, b"" + + def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool: + """Check if we need to discard our entire state: it should happen on + rollback or on dropping objects, because the same object may get + recreated and postgres would fail internal lookups. + """ + if self._names or prep == Prepare.SHOULD: + for result in results: + if result.status != COMMAND_OK: + continue + cmdstat = result.command_status + if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"): + return self.clear() + return False + + @staticmethod + def _check_results(results: Sequence["PGresult"]) -> bool: + """Return False if 'results' are invalid for prepared statement cache.""" + if len(results) != 1: + # We cannot prepare a multiple statement + return False + + status = results[0].status + if COMMAND_OK != status != TUPLES_OK: + # We don't prepare failed queries or other weird results + return False + + return True + + def _rotate(self) -> None: + """Evict an old value from the cache. + + If it was prepared, deallocate it. Do it only once: if the cache was + resized, deallocate gradually. + """ + if len(self._counts) > self.prepared_max: + self._counts.popitem(last=False) + + if len(self._names) > self.prepared_max: + name = self._names.popitem(last=False)[1] + self._maint_commands.append(b"DEALLOCATE " + name) + + def maybe_add_to_cache( + self, query: PostgresQuery, prep: Prepare, name: bytes + ) -> Optional[Key]: + """Handle 'query' for possible addition to the cache. + + If a new entry has been added, return its key. Return None otherwise + (meaning the query is already in cache or cache is not enabled). + + Note: This method is only called in pipeline mode. + """ + # don't do anything if prepared statements are disabled + if self.prepare_threshold is None: + return None + + key = self.key(query) + if key in self._counts: + if prep is Prepare.SHOULD: + del self._counts[key] + self._names[key] = name + else: + self._counts[key] += 1 + self._counts.move_to_end(key) + return None + + elif key in self._names: + self._names.move_to_end(key) + return None + + else: + if prep is Prepare.SHOULD: + self._names[key] = name + else: + self._counts[key] = 1 + return key + + def validate( + self, + key: Key, + prep: Prepare, + name: bytes, + results: Sequence["PGresult"], + ) -> None: + """Validate cached entry with 'key' by checking query 'results'. + + Possibly return a command to perform maintenance on database side. + + Note: this method is only called in pipeline mode. + """ + if self._should_discard(prep, results): + return + + if not self._check_results(results): + self._names.pop(key, None) + self._counts.pop(key, None) + else: + self._rotate() + + def clear(self) -> bool: + """Clear the cache of the maintenance commands. + + Clear the internal state and prepare a command to clear the state of + the server. 
+ """ + self._counts.clear() + if self._names: + self._names.clear() + self._maint_commands.clear() + self._maint_commands.append(b"DEALLOCATE ALL") + return True + else: + return False + + def get_maintenance_commands(self) -> Iterator[bytes]: + """ + Iterate over the commands needed to align the server state to our state + """ + while self._maint_commands: + yield self._maint_commands.popleft() diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py new file mode 100644 index 0000000..2a7554c --- /dev/null +++ b/psycopg/psycopg/_queries.py @@ -0,0 +1,375 @@ +""" +Utility module to manipulate queries +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional +from typing import Sequence, Tuple, Union, TYPE_CHECKING +from functools import lru_cache + +from . import pq +from . import errors as e +from .sql import Composable +from .abc import Buffer, Query, Params +from ._enums import PyFormat +from ._encodings import conn_encoding + +if TYPE_CHECKING: + from .abc import Transformer + + +class QueryPart(NamedTuple): + pre: bytes + item: Union[int, str] + format: PyFormat + + +class PostgresQuery: + """ + Helper to convert a Python query and parameters into Postgres format. + """ + + __slots__ = """ + query params types formats + _tx _want_formats _parts _encoding _order + """.split() + + def __init__(self, transformer: "Transformer"): + self._tx = transformer + + self.params: Optional[Sequence[Optional[Buffer]]] = None + # these are tuples so they can be used as keys e.g. in prepared stmts + self.types: Tuple[int, ...] = () + + # The format requested by the user and the ones to really pass Postgres + self._want_formats: Optional[List[PyFormat]] = None + self.formats: Optional[Sequence[pq.Format]] = None + + self._encoding = conn_encoding(transformer.connection) + self._parts: List[QueryPart] + self.query = b"" + self._order: Optional[List[str]] = None + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). + """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + ( + self.query, + self._want_formats, + self._order, + self._parts, + ) = _query2pg(bquery, self._encoding) + else: + self.query = bquery + self._want_formats = self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + assert self._want_formats is not None + self.params = self._tx.dump_sequence(params, self._want_formats) + self.types = self._tx.types or () + self.formats = self._tx.formats + else: + self.params = None + self.types = () + self.formats = None + + +class PostgresClientQuery(PostgresQuery): + """ + PostgresQuery subclass merging query and arguments client-side. + """ + + __slots__ = ("template",) + + def convert(self, query: Query, vars: Optional[Params]) -> None: + """ + Set up the query and parameters to convert. + + The results of this function can be obtained accessing the object + attributes (`query`, `params`, `types`, `formats`). 
+ """ + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + if vars is not None: + (self.template, self._order, self._parts) = _query2pg_client( + bquery, self._encoding + ) + else: + self.query = bquery + self._order = None + + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + """ + Process a new set of variables on the query processed by `convert()`. + + This method updates `params` and `types`. + """ + if vars is not None: + params = _validate_and_reorder_params(self._parts, vars, self._order) + self.params = tuple( + self._tx.as_literal(p) if p is not None else b"NULL" for p in params + ) + self.query = self.template % self.params + else: + self.params = None + + +@lru_cache() +def _query2pg( + query: bytes, encoding: str +) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into something Postgres understands. + + - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres + format (``$1``, ``$2``) + - placeholders can be %s, %t, or %b (auto, text or binary) + - return ``query`` (bytes), ``formats`` (list of formats) ``order`` + (sequence of names used in the query, in the position they appear) + ``parts`` (splits of queries and placeholders). + """ + parts = _split_query(query, encoding) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + formats = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"$%d" % (part.item + 1)) + formats.append(part.format) + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"$%d" % (len(seen) + 1) + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + formats.append(part.format) + else: + if seen[part.item][1] != part.format: + raise e.ProgrammingError( + f"placeholder '{part.item}' cannot have different formats" + ) + chunks.append(seen[part.item][0]) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), formats, order, parts + + +@lru_cache() +def _query2pg_client( + query: bytes, encoding: str +) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into a template to perform client-side binding + """ + parts = _split_query(query, encoding, collapse_double_percent=False) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"%s") + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"%s" + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + else: + chunks.append(seen[part.item][0]) + order.append(part.item) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), order, parts + + +def _validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] +) -> Sequence[Any]: + """ + Verify the compatibility between a query and a set of params. 
+ """ + # Try concrete types, then abstract types + t = type(vars) + if t is list or t is tuple: + sequence = True + elif t is dict: + sequence = False + elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + sequence = True + elif isinstance(vars, Mapping): + sequence = False + else: + raise TypeError( + "query parameters should be a sequence or a mapping," + f" got {type(vars).__name__}" + ) + + if sequence: + if len(vars) != len(parts) - 1: + raise e.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0].item, int): + raise TypeError("named placeholders require a mapping of parameters") + return vars # type: ignore[return-value] + + else: + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + try: + return [vars[item] for item in order or ()] # type: ignore[call-overload] + except KeyError: + raise e.ProgrammingError( + "query parameter missing:" + f" {', '.join(sorted(i for i in order or () if i not in vars))}" + ) + + +_re_placeholder = re.compile( + rb"""(?x) + % # a literal % + (?: + (?: + \( ([^)]+) \) # or a name in (braces) + . # followed by a format + ) + | + (?:.) # or any char, really + ) + """ +) + + +def _split_query( + query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True +) -> List[QueryPart]: + parts: List[Tuple[bytes, Optional[Match[bytes]]]] = [] + cur = 0 + + # pairs [(fragment, match], with the last match None + m = None + for m in _re_placeholder.finditer(query): + pre = query[cur : m.span(0)[0]] + parts.append((pre, m)) + cur = m.span(0)[1] + if m: + parts.append((query[cur:], None)) + else: + parts.append((query, None)) + + rv = [] + + # drop the "%%", validate + i = 0 + phtype = None + while i < len(parts): + pre, m = parts[i] + if m is None: + # last part + rv.append(QueryPart(pre, 0, PyFormat.AUTO)) + break + + ph = m.group(0) + if ph == b"%%": + # unescape '%%' to '%' if necessary, then merge the parts + if collapse_double_percent: + ph = b"%" + pre1, m1 = parts[i + 1] + parts[i + 1] = (pre + ph + pre1, m1) + del parts[i] + continue + + if ph == b"%(": + raise e.ProgrammingError( + "incomplete placeholder:" + f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'" + ) + elif ph == b"% ": + # explicit messasge for a typical error + raise e.ProgrammingError( + "incomplete placeholder: '%'; if you want to use '%' as an" + " operator you can double it up, i.e. use '%%'" + ) + elif ph[-1:] not in b"sbt": + raise e.ProgrammingError( + "only '%s', '%b', '%t' are allowed as placeholders, got" + f" '{m.group(0).decode(encoding)}'" + ) + + # Index or name + item: Union[int, str] + item = m.group(1).decode(encoding) if m.group(1) else i + + if not phtype: + phtype = type(item) + elif phtype is not type(item): + raise e.ProgrammingError( + "positional and named placeholders cannot be mixed" + ) + + format = _ph_to_fmt[ph[-1:]] + rv.append(QueryPart(pre, item, format)) + i += 1 + + return rv + + +_ph_to_fmt = { + b"s": PyFormat.AUTO, + b"t": PyFormat.TEXT, + b"b": PyFormat.BINARY, +} diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py new file mode 100644 index 0000000..28a6084 --- /dev/null +++ b/psycopg/psycopg/_struct.py @@ -0,0 +1,57 @@ +""" +Utility functions to deal with binary structs. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from typing import Callable, cast, Optional, Tuple +from typing_extensions import TypeAlias + +from .abc import Buffer +from . import errors as e +from ._compat import Protocol + +PackInt: TypeAlias = Callable[[int], bytes] +UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]] +PackFloat: TypeAlias = Callable[[float], bytes] +UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]] + + +class UnpackLen(Protocol): + def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: + ... + + +pack_int2 = cast(PackInt, struct.Struct("!h").pack) +pack_uint2 = cast(PackInt, struct.Struct("!H").pack) +pack_int4 = cast(PackInt, struct.Struct("!i").pack) +pack_uint4 = cast(PackInt, struct.Struct("!I").pack) +pack_int8 = cast(PackInt, struct.Struct("!q").pack) +pack_float4 = cast(PackFloat, struct.Struct("!f").pack) +pack_float8 = cast(PackFloat, struct.Struct("!d").pack) + +unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack) +unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack) +unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack) +unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack) +unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack) +unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack) +unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack) + +_struct_len = struct.Struct("!i") +pack_len = cast(Callable[[int], bytes], _struct_len.pack) +unpack_len = cast(UnpackLen, _struct_len.unpack_from) + + +def pack_float4_bug_304(x: float) -> bytes: + raise e.InterfaceError( + "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c" + " and psycopg-binary packages are not affected by this issue." + " See https://github.com/psycopg/psycopg/issues/304" + ) + + +# If issue #304 is detected, raise an error instead of dumping wrong data. +if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"): + pack_float4 = pack_float4_bug_304 diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py new file mode 100644 index 0000000..3528188 --- /dev/null +++ b/psycopg/psycopg/_tpc.py @@ -0,0 +1,116 @@ +""" +psycopg two-phase commit support +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +import datetime as dt +from base64 import b64encode, b64decode +from typing import Optional, Union +from dataclasses import dataclass, replace + +_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$") + + +@dataclass(frozen=True) +class Xid: + """A two-phase commit transaction identifier. + + The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`, + `bqual`). + + """ + + format_id: Optional[int] + gtrid: str + bqual: Optional[str] + prepared: Optional[dt.datetime] = None + owner: Optional[str] = None + database: Optional[str] = None + + @classmethod + def from_string(cls, s: str) -> "Xid": + """Try to parse an XA triple from the string. + + This may fail for several reasons. In such case return an unparsed Xid. 
+ """ + try: + return cls._parse_string(s) + except Exception: + return Xid(None, s, None) + + def __str__(self) -> str: + return self._as_tid() + + def __len__(self) -> int: + return 3 + + def __getitem__(self, index: int) -> Union[int, str, None]: + return (self.format_id, self.gtrid, self.bqual)[index] + + @classmethod + def _parse_string(cls, s: str) -> "Xid": + m = _re_xid.match(s) + if not m: + raise ValueError("bad Xid format") + + format_id = int(m.group(1)) + gtrid = b64decode(m.group(2)).decode() + bqual = b64decode(m.group(3)).decode() + return cls.from_parts(format_id, gtrid, bqual) + + @classmethod + def from_parts( + cls, format_id: Optional[int], gtrid: str, bqual: Optional[str] + ) -> "Xid": + if format_id is not None: + if bqual is None: + raise TypeError("if format_id is specified, bqual must be too") + if not 0 <= format_id < 0x80000000: + raise ValueError("format_id must be a non-negative 32-bit integer") + if len(bqual) > 64: + raise ValueError("bqual must be not longer than 64 chars") + if len(gtrid) > 64: + raise ValueError("gtrid must be not longer than 64 chars") + + elif bqual is None: + raise TypeError("if format_id is None, bqual must be None too") + + return Xid(format_id, gtrid, bqual) + + def _as_tid(self) -> str: + """ + Return the PostgreSQL transaction_id for this XA xid. + + PostgreSQL wants just a string, while the DBAPI supports the XA + standard and thus a triple. We use the same conversion algorithm + implemented by JDBC in order to allow some form of interoperation. + + see also: the pgjdbc implementation + http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/ + postgresql/xa/RecoveredXid.java?rev=1.2 + """ + if self.format_id is None or self.bqual is None: + # Unparsed xid: return the gtrid. + return self.gtrid + + # XA xid: mash together the components. + egtrid = b64encode(self.gtrid.encode()).decode() + ebqual = b64encode(self.bqual.encode()).decode() + + return f"{self.format_id}_{egtrid}_{ebqual}" + + @classmethod + def _get_recover_query(cls) -> str: + return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts" + + @classmethod + def _from_record( + cls, gid: str, prepared: dt.datetime, owner: str, database: str + ) -> "Xid": + xid = Xid.from_string(gid) + return replace(xid, prepared=prepared, owner=owner, database=database) + + +Xid.__module__ = "psycopg" diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py new file mode 100644 index 0000000..19bd6ae --- /dev/null +++ b/psycopg/psycopg/_transform.py @@ -0,0 +1,350 @@ +""" +Helper object to transform values between Python and PostgreSQL +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import DefaultDict, TYPE_CHECKING +from collections import defaultdict +from typing_extensions import TypeAlias + +from . import pq +from . import postgres +from . 
import errors as e +from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType +from .rows import Row, RowMaker +from .postgres import INVALID_OID, TEXT_OID +from ._encodings import pgconn_encoding + +if TYPE_CHECKING: + from .abc import Dumper, Loader + from .adapt import AdaptersMap + from .pq.abc import PGresult + from .connection import BaseConnection + +DumperCache: TypeAlias = Dict[DumperKey, "Dumper"] +OidDumperCache: TypeAlias = Dict[int, "Dumper"] +LoaderCache: TypeAlias = Dict[int, "Loader"] + +TEXT = pq.Format.TEXT +PY_TEXT = PyFormat.TEXT + + +class Transformer(AdaptContext): + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that attributes + such as the server version or the connection encoding will not change. The + object have its state so adapting several values of the same type can be + optimised. + + """ + + __module__ = "psycopg.adapt" + + __slots__ = """ + types formats + _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid + _oid_dumpers _oid_types _row_dumpers _row_loaders + """.split() + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + _adapters: "AdaptersMap" + _pgresult: Optional["PGresult"] + _none_oid: int + + def __init__(self, context: Optional[AdaptContext] = None): + self._pgresult = self.types = self.formats = None + + # WARNING: don't store context, or you'll create a loop with the Cursor + if context: + self._adapters = context.adapters + self._conn = context.connection + else: + self._adapters = postgres.adapters + self._conn = None + + # mapping fmt, class -> Dumper instance + self._dumpers: DefaultDict[PyFormat, DumperCache] + self._dumpers = defaultdict(dict) + + # mapping fmt, oid -> Dumper instance + # Not often used, so create it only if needed. + self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]] + self._oid_dumpers = None + + # mapping fmt, oid -> Loader instance + self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {}) + + self._row_dumpers: Optional[List["Dumper"]] = None + + # sequence of load functions from value to python + # the length of the result columns + self._row_loaders: List[LoadFunc] = [] + + # mapping oid -> type sql representation + self._oid_types: Dict[int, bytes] = {} + + self._encoding = "" + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + """ + Return a Transformer from an AdaptContext. + + If the context is a Transformer instance, just return it. 
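+
+        For instance (illustrative), ``Transformer.from_context(cursor)``
+        builds a new transformer using the cursor's adapters, whereas passing
+        an existing transformer returns that same object.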
+ """ + if isinstance(context, Transformer): + return context + else: + return cls(context) + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + return self._conn + + @property + def encoding(self) -> str: + if not self._encoding: + conn = self.connection + self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" + return self._encoding + + @property + def adapters(self) -> "AdaptersMap": + return self._adapters + + @property + def pgresult(self) -> Optional["PGresult"]: + return self._pgresult + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None, + ) -> None: + self._pgresult = result + + if not result: + self._nfields = self._ntuples = 0 + if set_loaders: + self._row_loaders = [] + return + + self._ntuples = result.ntuples + nf = self._nfields = result.nfields + + if not set_loaders: + return + + if not nf: + self._row_loaders = [] + return + + fmt: pq.Format + fmt = result.fformat(0) if format is None else format # type: ignore + self._row_loaders = [ + self.get_loader(result.ftype(i), fmt).load for i in range(nf) + ] + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types] + self.types = tuple(types) + self.formats = [format] * len(types) + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + self._row_loaders = [self.get_loader(oid, format).load for oid in types] + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + nparams = len(params) + out: List[Optional[Buffer]] = [None] * nparams + + # If we have dumpers, it means set_dumper_types had been called, in + # which case self.types and self.formats are set to sequences of the + # right size. + if self._row_dumpers: + for i in range(nparams): + param = params[i] + if param is not None: + out[i] = self._row_dumpers[i].dump(param) + return out + + types = [self._get_none_oid()] * nparams + pqformats = [TEXT] * nparams + + for i in range(nparams): + param = params[i] + if param is None: + continue + dumper = self.get_dumper(param, formats[i]) + out[i] = dumper.dump(param) + types[i] = dumper.oid + pqformats[i] = dumper.format + + self.types = tuple(types) + self.formats = pqformats + + return out + + def as_literal(self, obj: Any) -> bytes: + dumper = self.get_dumper(obj, PY_TEXT) + rv = dumper.quote(obj) + # If the result is quoted, and the oid not unknown or text, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + oid = dumper.oid + if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID: + try: + type_sql = self._oid_types[oid] + except KeyError: + ti = self.adapters.types.get(oid) + if ti: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + else: + type_sql = b"" + self._oid_types[oid] = type_sql + + if type_sql: + rv = b"%s::%s" % (rv, type_sql) + + if not isinstance(rv, bytes): + rv = bytes(rv) + return rv + + def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Return a Dumper instance to dump `!obj`. 
+ """ + # Normally, the type of the object dictates how to dump it + key = type(obj) + + # Reuse an existing Dumper class for objects of the same type + cache = self._dumpers[format] + try: + dumper = cache[key] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper(key, format) + cache[key] = dumper = dcls(key, self) + + # Check if the dumper requires an upgrade to handle this specific value + key1 = dumper.get_key(obj, format) + if key1 is key: + return dumper + + # If it does, ask the dumper to create its own upgraded version + try: + return cache[key1] + except KeyError: + dumper = cache[key1] = dumper.upgrade(obj, format) + return dumper + + def _get_none_oid(self) -> int: + try: + return self._none_oid + except AttributeError: + pass + + try: + rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid + except KeyError: + raise e.InterfaceError("None dumper not found") + + return rv + + def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper": + """ + Return a Dumper to dump an object to the type with given oid. + """ + if not self._oid_dumpers: + self._oid_dumpers = ({}, {}) + + # Reuse an existing Dumper class for objects of the same type + cache = self._oid_dumpers[format] + try: + return cache[oid] + except KeyError: + # If it's the first time we see this type, look for a dumper + # configured for it. + dcls = self.adapters.get_dumper_by_oid(oid, format) + cache[oid] = dumper = dcls(NoneType, self) + + return dumper + + def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: + res = self._pgresult + if not res: + raise e.InterfaceError("result not set") + + if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): + raise e.InterfaceError( + f"rows must be included between 0 and {self._ntuples}" + ) + + records = [] + for row in range(row0, row1): + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + records.append(make_row(record)) + + return records + + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: + res = self._pgresult + if not res: + return None + + if not 0 <= row < self._ntuples: + return None + + record: List[Any] = [None] * self._nfields + for col in range(self._nfields): + val = res.get_value(row, col) + if val is not None: + record[col] = self._row_loaders[col](val) + + return make_row(record) + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + if len(self._row_loaders) != len(record): + raise e.ProgrammingError( + f"cannot load sequence of {len(record)} items:" + f" {len(self._row_loaders)} loaders registered" + ) + + return tuple( + (self._row_loaders[i](val) if val is not None else None) + for i, val in enumerate(record) + ) + + def get_loader(self, oid: int, format: pq.Format) -> "Loader": + try: + return self._loaders[format][oid] + except KeyError: + pass + + loader_cls = self._adapters.get_loader(oid, format) + if not loader_cls: + loader_cls = self._adapters.get_loader(INVALID_OID, format) + if not loader_cls: + raise e.InterfaceError("unknown oid loader not found") + loader = self._loaders[format][oid] = loader_cls(oid, self) + return loader diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py new file mode 100644 index 0000000..2f1a24d --- /dev/null +++ b/psycopg/psycopg/_typeinfo.py @@ -0,0 +1,461 @@ +""" 
+Information about PostgreSQL types + +These types allow to read information from the system catalog and provide +information to the adapters if needed. +""" + +# Copyright (C) 2020 The Psycopg Team +from enum import Enum +from typing import Any, Dict, Iterator, Optional, overload +from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import errors as e +from .abc import AdaptContext +from .rows import dict_row + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + from .sql import Identifier + +T = TypeVar("T", bound="TypeInfo") +RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]] + + +class TypeInfo: + """ + Hold information about a PostgreSQL base type. + """ + + __module__ = "psycopg.types" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + delimiter: str = ",", + ): + self.name = name + self.oid = oid + self.array_oid = array_oid + self.regtype = regtype or name + self.delimiter = delimiter + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__qualname__}:" + f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>" + ) + + @overload + @classmethod + def fetch( + cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"] + ) -> Optional[T]: + ... + + @overload + @classmethod + async def fetch( + cls: Type[T], + conn: "AsyncConnection[Any]", + name: Union[str, "Identifier"], + ) -> Optional[T]: + ... + + @classmethod + def fetch( + cls: Type[T], + conn: "Union[Connection[Any], AsyncConnection[Any]]", + name: Union[str, "Identifier"], + ) -> Any: + """Query a system catalog to read information about a type.""" + from .sql import Composable + from .connection_async import AsyncConnection + + if isinstance(name, Composable): + name = name.as_string(conn) + + if isinstance(conn, AsyncConnection): + return cls._fetch_async(conn, name) + + # This might result in a nested transaction. What we want is to leave + # the function with the connection in the state we found (either idle + # or intrans) + try: + with conn.transaction(): + with conn.cursor(binary=True, row_factory=dict_row) as cur: + cur.execute(cls._get_info_query(conn), {"name": name}) + recs = cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + async def _fetch_async( + cls: Type[T], conn: "AsyncConnection[Any]", name: str + ) -> Optional[T]: + """ + Query a system catalog to read information about a type. + + Similar to `fetch()` but can use an asynchronous connection. + """ + try: + async with conn.transaction(): + async with conn.cursor(binary=True, row_factory=dict_row) as cur: + await cur.execute(cls._get_info_query(conn), {"name": name}) + recs = await cur.fetchall() + except e.UndefinedObject: + return None + + return cls._from_records(name, recs) + + @classmethod + def _from_records( + cls: Type[T], name: str, recs: Sequence[Dict[str, Any]] + ) -> Optional[T]: + if len(recs) == 1: + return cls(**recs[0]) + elif not recs: + return None + else: + raise e.ProgrammingError(f"found {len(recs)} different types named {name}") + + def register(self, context: Optional[AdaptContext] = None) -> None: + """ + Register the type information, globally or in the specified `!context`. + """ + if context: + types = context.adapters.types + else: + from . 
import postgres + + types = postgres.types + + types.add(self) + + if self.array_oid: + from .types.array import register_array + + register_array(self, context) + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + typname AS name, oid, typarray AS array_oid, + oid::regtype::text AS regtype, typdelim AS delimiter +FROM pg_type t +WHERE t.oid = %(name)s::regtype +ORDER BY t.oid +""" + + def _added(self, registry: "TypesRegistry") -> None: + """Method called by the `!registry` when the object is added there.""" + pass + + +class RangeInfo(TypeInfo): + """Manage information about a range type.""" + + __module__ = "psycopg.types.range" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngtypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map ranges subtypes to info + registry._registry[RangeInfo, self.subtype_oid] = self + + +class MultirangeInfo(TypeInfo): + """Manage information about a multirange type.""" + + # TODO: expose to multirange module once added + # __module__ = "psycopg.types.multirange" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + range_oid: int, + subtype_oid: int, + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.range_oid = range_oid + self.subtype_oid = subtype_oid + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + if conn.info.server_version < 140000: + raise e.NotSupportedError( + "multirange types are only available from PostgreSQL 14" + ) + return """\ +SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid +FROM pg_type t +JOIN pg_range r ON t.oid = r.rngmultitypid +WHERE t.oid = %(name)s::regtype +""" + + def _added(self, registry: "TypesRegistry") -> None: + # Map multiranges ranges and subtypes to info + registry._registry[MultirangeInfo, self.range_oid] = self + registry._registry[MultirangeInfo, self.subtype_oid] = self + + +class CompositeInfo(TypeInfo): + """Manage information about a composite type.""" + + __module__ = "psycopg.types.composite" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + *, + regtype: str = "", + field_names: Sequence[str], + field_types: Sequence[int], + ): + super().__init__(name, oid, array_oid, regtype=regtype) + self.field_names = field_names + self.field_types = field_types + # Will be set by register() if the `factory` is a type + self.python_type: Optional[type] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + t.oid::regtype::text AS regtype, + coalesce(a.fnames, '{}') AS field_names, + coalesce(a.ftypes, '{}') AS field_types +FROM pg_type t +LEFT JOIN ( + SELECT + attrelid, + array_agg(attname) AS fnames, + array_agg(atttypid) AS ftypes + 
FROM ( + SELECT a.attrelid, a.attname, a.atttypid + FROM pg_attribute a + JOIN pg_type t ON t.typrelid = a.attrelid + WHERE t.oid = %(name)s::regtype + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + ) x + GROUP BY attrelid +) a ON a.attrelid = t.typrelid +WHERE t.oid = %(name)s::regtype +""" + + +class EnumInfo(TypeInfo): + """Manage information about an enum type.""" + + __module__ = "psycopg.types.enum" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + labels: Sequence[str], + ): + super().__init__(name, oid, array_oid) + self.labels = labels + # Will be set by register_enum() + self.enum: Optional[Type[Enum]] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT name, oid, array_oid, array_agg(label) AS labels +FROM ( + SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + e.enumlabel AS label + FROM pg_type t + LEFT JOIN pg_enum e + ON e.enumtypid = t.oid + WHERE t.oid = %(name)s::regtype + ORDER BY e.enumsortorder +) x +GROUP BY name, oid, array_oid +""" + + +class TypesRegistry: + """ + Container for the information about types in a database. + """ + + __module__ = "psycopg.types" + + def __init__(self, template: Optional["TypesRegistry"] = None): + self._registry: Dict[RegistryKey, TypeInfo] + + # Make a shallow copy: it will become a proper copy if the registry + # is edited. + if template: + self._registry = template._registry + self._own_state = False + template._own_state = False + else: + self.clear() + + def clear(self) -> None: + self._registry = {} + self._own_state = True + + def add(self, info: TypeInfo) -> None: + self._ensure_own_state() + if info.oid: + self._registry[info.oid] = info + if info.array_oid: + self._registry[info.array_oid] = info + self._registry[info.name] = info + + if info.regtype and info.regtype not in self._registry: + self._registry[info.regtype] = info + + # Allow info to customise further their relation with the registry + info._added(self) + + def __iter__(self) -> Iterator[TypeInfo]: + seen = set() + for t in self._registry.values(): + if id(t) not in seen: + seen.add(id(t)) + yield t + + @overload + def __getitem__(self, key: Union[str, int]) -> TypeInfo: + ... + + @overload + def __getitem__(self, key: Tuple[Type[T], int]) -> T: + ... + + def __getitem__(self, key: RegistryKey) -> TypeInfo: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Raise KeyError if not found. + """ + if isinstance(key, str): + if key.endswith("[]"): + key = key[:-2] + elif not isinstance(key, (int, tuple)): + raise TypeError(f"the key must be an oid or a name, got {type(key)}") + try: + return self._registry[key] + except KeyError: + raise KeyError(f"couldn't find the type {key!r} in the types registry") + + @overload + def get(self, key: Union[str, int]) -> Optional[TypeInfo]: + ... + + @overload + def get(self, key: Tuple[Type[T], int]) -> Optional[T]: + ... + + def get(self, key: RegistryKey) -> Optional[TypeInfo]: + """ + Return info about a type, specified by name or oid + + :param key: the name or oid of the type to look for. + + Unlike `__getitem__`, return None if not found. + """ + try: + return self[key] + except KeyError: + return None + + def get_oid(self, name: str) -> int: + """ + Return the oid of a PostgreSQL type by name. + + :param key: the name of the type to look for. 
+ + Return the array oid if the type ends with "``[]``" + + Raise KeyError if the name is unknown. + """ + t = self[name] + if name.endswith("[]"): + return t.array_oid + else: + return t.oid + + def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]: + """ + Return info about a `TypeInfo` subclass by its element name or oid. + + :param cls: the subtype of `!TypeInfo` to look for. Currently + supported are `~psycopg.types.range.RangeInfo` and + `~psycopg.types.multirange.MultirangeInfo`. + :param subtype: The name or OID of the subtype of the element to look for. + :return: The `!TypeInfo` object of class `!cls` whose subtype is + `!subtype`. `!None` if the element or its range are not found. + """ + try: + info = self[subtype] + except KeyError: + return None + return self.get((cls, info.oid)) + + def _ensure_own_state(self) -> None: + # Time to write! so, copy. + if not self._own_state: + self._registry = self._registry.copy() + self._own_state = True diff --git a/psycopg/psycopg/_tz.py b/psycopg/psycopg/_tz.py new file mode 100644 index 0000000..813ed62 --- /dev/null +++ b/psycopg/psycopg/_tz.py @@ -0,0 +1,44 @@ +""" +Timezone utility functions. +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import Dict, Optional, Union +from datetime import timezone, tzinfo + +from .pq.abc import PGconn +from ._compat import ZoneInfo + +logger = logging.getLogger("psycopg") + +_timezones: Dict[Union[None, bytes], tzinfo] = { + None: timezone.utc, + b"UTC": timezone.utc, +} + + +def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo: + """Return the Python timezone info of the connection's timezone.""" + tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None + try: + return _timezones[tzname] + except KeyError: + sname = tzname.decode() if tzname else "UTC" + try: + zi: tzinfo = ZoneInfo(sname) + except (KeyError, OSError): + logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname) + zi = timezone.utc + except Exception as ex: + logger.warning( + "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)", + sname, + type(ex).__name__, + ex, + ) + zi = timezone.utc + + _timezones[tzname] = zi + return zi diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py new file mode 100644 index 0000000..f861741 --- /dev/null +++ b/psycopg/psycopg/_wrappers.py @@ -0,0 +1,137 @@ +""" +Wrappers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Wrappers to force numbers to be cast as specific PostgreSQL types + +# These types are implemented here but exposed by `psycopg.types.numeric`. +# They are defined here to avoid a circular import. +_MODULE = "psycopg.types.numeric" + + +class Int2(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int2": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int4(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Int8(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`. 
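+
+    For instance (illustrative; ``cur`` is any psycopg cursor),
+    ``cur.execute("SELECT %s", [Int8(1)])`` sends the value as :sql:`bigint`
+    instead of letting the dumper choose the integer width.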
+ """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Int8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class IntNumeric(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "IntNumeric": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float4(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float4": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Float8(float): + """ + Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: float) -> "Float8": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + +class Oid(int): + """ + Force dumping a Python `!int` as a PostgreSQL :sql:`oid`. + """ + + __module__ = _MODULE + __slots__ = () + + def __new__(cls, arg: int) -> "Oid": + return super().__new__(cls, arg) + + def __str__(self) -> str: + return super().__repr__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py new file mode 100644 index 0000000..80c8fbf --- /dev/null +++ b/psycopg/psycopg/abc.py @@ -0,0 +1,266 @@ +""" +Protocol objects representing different implementations of the same classes. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Generator, Mapping +from typing import List, Optional, Sequence, Tuple, TypeVar, Union +from typing import TYPE_CHECKING +from typing_extensions import TypeAlias + +from . import pq +from ._enums import PyFormat as PyFormat +from ._compat import Protocol, LiteralString + +if TYPE_CHECKING: + from . import sql + from .rows import Row, RowMaker + from .pq.abc import PGresult + from .waiting import Wait, Ready + from .connection import BaseConnection + from ._adapters_map import AdaptersMap + +NoneType: type = type(None) + +# An object implementing the buffer protocol +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + +Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"] +Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] +ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") +PipelineCommand: TypeAlias = Callable[[], None] +DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]] + +# Waiting protocol types + +RV = TypeVar("RV") + +PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV] +"""Generator for processes where the connection file number can change. + +This can happen in connection and reset, but not in normal querying. +""" + +PQGen: TypeAlias = Generator["Wait", "Ready", RV] +"""Generator for processes where the connection file number won't change. 
+""" + + +class WaitFunc(Protocol): + """ + Wait on the connection which generated `PQgen` and return its final result. + """ + + def __call__( + self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None + ) -> RV: + ... + + +# Adaptation types + +DumpFunc: TypeAlias = Callable[[Any], Buffer] +LoadFunc: TypeAlias = Callable[[Buffer], Any] + + +class AdaptContext(Protocol): + """ + A context describing how types are adapted. + + Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`, + `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`. + + Note that this is a `~typing.Protocol`, so objects implementing + `!AdaptContext` don't need to explicitly inherit from this class. + + """ + + @property + def adapters(self) -> "AdaptersMap": + """The adapters configuration that this object uses.""" + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + """The connection used by this object, if available. + + :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None` + """ + ... + + +class Dumper(Protocol): + """ + Convert Python objects of type `!cls` to PostgreSQL representation. + """ + + format: pq.Format + """ + The format that this class `dump()` method produces, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + oid: int + """The oid to pass to the server, if known; 0 otherwise (class attribute).""" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + ... + + def dump(self, obj: Any) -> Buffer: + """Convert the object `!obj` to PostgreSQL representation. + + :param obj: the object to convert. + """ + ... + + def quote(self, obj: Any) -> Buffer: + """Convert the object `!obj` to escaped representation. + + :param obj: the object to convert. + """ + ... + + def get_key(self, obj: Any, format: PyFormat) -> DumperKey: + """Return an alternative key to upgrade the dumper to represent `!obj`. + + :param obj: The object to convert + :param format: The format to convert to + + Normally the type of the object is all it takes to define how to dump + the object to the database. For instance, a Python `~datetime.date` can + be simply converted into a PostgreSQL :sql:`date`. + + In a few cases, just the type is not enough. For example: + + - A Python `~datetime.datetime` could be represented as a + :sql:`timestamptz` or a :sql:`timestamp`, according to whether it + specifies a `!tzinfo` or not. + + - A Python int could be stored as several Postgres types: int2, int4, + int8, numeric. If a type too small is used, it may result in an + overflow. If a type too large is used, PostgreSQL may not want to + cast it to a smaller type. + + - Python lists should be dumped according to the type they contain to + convert them to e.g. array of strings, array of ints (and which + size of int?...) + + In these cases, a dumper can implement `!get_key()` and return a new + class, or sequence of classes, that can be used to identify the same + dumper again. If the mechanism is not needed, the method should return + the same `!cls` object passed in the constructor. + + If a dumper implements `get_key()` it should also implement + `upgrade()`. + + """ + ... + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """Return a new dumper to manage `!obj`. 
+ + :param obj: The object to convert + :param format: The format to convert to + + Once `Transformer.get_dumper()` has been notified by `get_key()` that + this Dumper class cannot handle `!obj` itself, it will invoke + `!upgrade()`, which should return a new `Dumper` instance, which will + be reused for every objects for which `!get_key()` returns the same + result. + """ + ... + + +class Loader(Protocol): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format + """ + The format that this class `load()` method can convert, + `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`. + + This is a class attribute. + """ + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + ... + + def load(self, data: Buffer) -> Any: + """ + Convert the data returned by the database into a Python object. + + :param data: the data to convert. + """ + ... + + +class Transformer(Protocol): + + types: Optional[Tuple[int, ...]] + formats: Optional[List[pq.Format]] + + def __init__(self, context: Optional[AdaptContext] = None): + ... + + @classmethod + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": + ... + + @property + def connection(self) -> Optional["BaseConnection[Any]"]: + ... + + @property + def encoding(self) -> str: + ... + + @property + def adapters(self) -> "AdaptersMap": + ... + + @property + def pgresult(self) -> Optional["PGresult"]: + ... + + def set_pgresult( + self, + result: Optional["PGresult"], + *, + set_loaders: bool = True, + format: Optional[pq.Format] = None + ) -> None: + ... + + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: + ... + + def dump_sequence( + self, params: Sequence[Any], formats: Sequence[PyFormat] + ) -> Sequence[Optional[Buffer]]: + ... + + def as_literal(self, obj: Any) -> bytes: + ... + + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: + ... + + def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]: + ... + + def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: + ... + + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: + ... + + def get_loader(self, oid: int, format: pq.Format) -> Loader: + ... diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py new file mode 100644 index 0000000..7ec4a55 --- /dev/null +++ b/psycopg/psycopg/adapt.py @@ -0,0 +1,162 @@ +""" +Entry point into the adaptation system. +""" + +# Copyright (C) 2020 The Psycopg Team + +from abc import ABC, abstractmethod +from typing import Any, Optional, Type, TYPE_CHECKING + +from . import pq, abc +from . import _adapters_map +from ._enums import PyFormat as PyFormat +from ._cmodule import _psycopg + +if TYPE_CHECKING: + from .connection import BaseConnection + +AdaptersMap = _adapters_map.AdaptersMap +Buffer = abc.Buffer + +ORD_BS = ord("\\") + + +class Dumper(abc.Dumper, ABC): + """ + Convert Python object of the type `!cls` to PostgreSQL representation. 
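+
+    A minimal subclass sketch (illustrative only, not one of the dumpers
+    shipped with the package)::
+
+        class StrUpperDumper(Dumper):
+            def dump(self, obj):
+                return obj.upper().encode("utf-8")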
+ """ + + oid: int = 0 + """The oid to pass to the server, if known.""" + + format: pq.Format = pq.Format.TEXT + """The format of the data dumped.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + self.cls = cls + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + def __repr__(self) -> str: + return ( + f"<{type(self).__module__}.{type(self).__qualname__}" + f" (oid={self.oid}) at 0x{id(self):x}>" + ) + + @abstractmethod + def dump(self, obj: Any) -> Buffer: + ... + + def quote(self, obj: Any) -> Buffer: + """ + By default return the `dump()` value quoted and sanitised, so + that the result can be used to build a SQL string. This works well + for most types and you won't likely have to implement this method in a + subclass. + """ + value = self.dump(obj) + + if self.connection: + esc = pq.Escaping(self.connection.pgconn) + # escaping and quoting + return esc.escape_literal(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + # No quoting, only quote escaping, random bs escaping. See further. + esc = pq.Escaping() + out = esc.escape_string(value) + + # b"\\" in memoryview doesn't work so search for the ascii value + if ORD_BS not in out: + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + return b"'" + out + b"'" + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + rv: bytes = b" E'" + out + b"'" + if esc.escape_string(b"\\") == b"\\": + rv = rv.replace(b"\\", b"\\\\") + return rv + + def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey: + """ + Implementation of the `~psycopg.abc.Dumper.get_key()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation returns the `!cls` passed in the constructor. + Subclasses needing to specialise the PostgreSQL type according to the + *value* of the object dumped (not only according to to its type) + should override this class. + + """ + return self.cls + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + """ + Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the + `~psycopg.abc.Dumper` protocol. Look at its definition for details. + + This implementation just returns `!self`. If a subclass implements + `get_key()` it should probably override `!upgrade()` too. + """ + return self + + +class Loader(abc.Loader, ABC): + """ + Convert PostgreSQL values with type OID `!oid` to Python objects. + """ + + format: pq.Format = pq.Format.TEXT + """The format of the data loaded.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + self.oid = oid + self.connection: Optional["BaseConnection[Any]"] = ( + context.connection if context else None + ) + + @abstractmethod + def load(self, data: Buffer) -> Any: + """Convert a PostgreSQL value to a Python object.""" + ... 
+ + +Transformer: Type["abc.Transformer"] + +# Override it with fast object if available +if _psycopg: + Transformer = _psycopg.Transformer +else: + from . import _transform + + Transformer = _transform.Transformer + + +class RecursiveDumper(Dumper): + """Dumper with a transformer to help dumping recursive types.""" + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self._tx = Transformer.from_context(context) + + +class RecursiveLoader(Loader): + """Loader with a transformer to help loading recursive types.""" + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer.from_context(context) diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py new file mode 100644 index 0000000..6271ec5 --- /dev/null +++ b/psycopg/psycopg/client_cursor.py @@ -0,0 +1,95 @@ +""" +psycopg client-side binding cursors +""" + +# Copyright (C) 2022 The Psycopg Team + +from typing import Optional, Tuple, TYPE_CHECKING +from functools import partial + +from ._queries import PostgresQuery, PostgresClientQuery + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params +from .rows import Row +from .cursor import BaseCursor, Cursor +from ._preparing import Prepare +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from typing import Any # noqa: F401 + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + + +class ClientCursorMixin(BaseCursor[ConnectionType, Row]): + def mogrify(self, query: Query, params: Optional[Params] = None) -> str: + """ + Return the query and parameters merged. + + Parameters are adapted and merged to the query the same way that + `!execute()` would do. + + """ + self._tx = adapt.Transformer(self) + pgq = self._convert_query(query, params) + return pgq.query.decode(self._tx.encoding) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if fmt == BINARY: + raise e.NotSupportedError( + "client-side cursors don't support binary results" + ) + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial(self._pgconn.send_query_params, query.query, None) + ) + elif force_extended: + self._pgconn.send_query_params(query.query, None) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. 
+ self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresClientQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return (Prepare.NO, b"") + + +class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + + +class AsyncClientCursor( + ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py new file mode 100644 index 0000000..78ad577 --- /dev/null +++ b/psycopg/psycopg/connection.py @@ -0,0 +1,1031 @@ +""" +psycopg connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +import threading +from types import TracebackType +from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator +from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union +from typing import overload, TYPE_CHECKING +from weakref import ref, ReferenceType +from warnings import warn +from functools import partial +from contextlib import contextmanager +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from . import waiting +from . import postgres +from .abc import AdaptContext, ConnectionType, Params, Query, RV +from .abc import PQGen, PQGenConn +from .sql import Composable, SQL +from ._tpc import Xid +from .rows import Row, RowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .cursor import Cursor +from ._compat import LiteralString +from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from ._pipeline import BasePipeline, Pipeline +from .generators import notifies, connect, execute +from ._encodings import pgconn_encoding +from ._preparing import PrepareManager +from .transaction import Transaction +from .server_cursor import ServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn, PGresult + from psycopg_pool.base import BasePool + + +# Row Type variable for Cursor (when it needs to be distinguished from the +# connection's one) +CursorRow = TypeVar("CursorRow") + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class Notify(NamedTuple): + """An asynchronous notification received from the database.""" + + channel: str + """The name of the channel on which the notification was received.""" + + payload: str + """The message attached to the notification.""" + + pid: int + """The PID of the backend process which sent the notification.""" + + +Notify.__module__ = "psycopg" + +NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] +NotifyHandler: TypeAlias = Callable[[Notify], None] + + +class BaseConnection(Generic[Row]): + """ + Base class for different types of connections. + + Share common functionalities such as access to the wrapped PGconn, but + allow different interfaces (sync/async). 
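+
+    Typical use goes through the concrete subclasses, e.g. (illustrative)::
+
+        with psycopg.connect("dbname=test") as conn:
+            conn.execute("SELECT 1")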
+ """ + + # DBAPI2 exposed exceptions + Warning = e.Warning + Error = e.Error + InterfaceError = e.InterfaceError + DatabaseError = e.DatabaseError + DataError = e.DataError + OperationalError = e.OperationalError + IntegrityError = e.IntegrityError + InternalError = e.InternalError + ProgrammingError = e.ProgrammingError + NotSupportedError = e.NotSupportedError + + # Enums useful for the connection + ConnStatus = pq.ConnStatus + TransactionStatus = pq.TransactionStatus + + def __init__(self, pgconn: "PGconn"): + self.pgconn = pgconn + self._autocommit = False + + # None, but set to a copy of the global adapters map as soon as requested. + self._adapters: Optional[AdaptersMap] = None + + self._notice_handlers: List[NoticeHandler] = [] + self._notify_handlers: List[NotifyHandler] = [] + + # Number of transaction blocks currently entered + self._num_transactions = 0 + + self._closed = False # closed by an explicit close() + self._prepared: PrepareManager = PrepareManager() + self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared + + wself = ref(self) + pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) + pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) + + # Attribute is only set if the connection is from a pool so we can tell + # apart a connection in the pool too (when _pool = None) + self._pool: Optional["BasePool[Any]"] + + self._pipeline: Optional[BasePipeline] = None + + # Time after which the connection should be closed + self._expire_at: float + + self._isolation_level: Optional[IsolationLevel] = None + self._read_only: Optional[bool] = None + self._deferrable: Optional[bool] = None + self._begin_statement = b"" + + def __del__(self) -> None: + # If fails on connection we might not have this attribute yet + if not hasattr(self, "pgconn"): + return + + # Connection correctly closed + if self.closed: + return + + # Connection in a pool so terminating with the program is normal + if hasattr(self, "_pool"): + return + + warn( + f"connection {self} was deleted while still open." + " Please use 'with' or '.close()' to close the connection", + ResourceWarning, + ) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def closed(self) -> bool: + """`!True` if the connection is closed.""" + return self.pgconn.status == BAD + + @property + def broken(self) -> bool: + """ + `!True` if the connection was interrupted. + + A broken connection is always `closed`, but wasn't closed in a clean + way, such as using `close()` or a `!with` block. + """ + return self.pgconn.status == BAD and not self._closed + + @property + def autocommit(self) -> bool: + """The autocommit state of the connection.""" + return self._autocommit + + @autocommit.setter + def autocommit(self, value: bool) -> None: + self._set_autocommit(value) + + def _set_autocommit(self, value: bool) -> None: + raise NotImplementedError + + def _set_autocommit_gen(self, value: bool) -> PQGen[None]: + yield from self._check_intrans_gen("autocommit") + self._autocommit = bool(value) + + @property + def isolation_level(self) -> Optional[IsolationLevel]: + """ + The isolation level of the new transactions started on the connection. 
+ """ + return self._isolation_level + + @isolation_level.setter + def isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._set_isolation_level(value) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + raise NotImplementedError + + def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]: + yield from self._check_intrans_gen("isolation_level") + self._isolation_level = IsolationLevel(value) if value is not None else None + self._begin_statement = b"" + + @property + def read_only(self) -> Optional[bool]: + """ + The read-only state of the new transactions started on the connection. + """ + return self._read_only + + @read_only.setter + def read_only(self, value: Optional[bool]) -> None: + self._set_read_only(value) + + def _set_read_only(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("read_only") + self._read_only = bool(value) + self._begin_statement = b"" + + @property + def deferrable(self) -> Optional[bool]: + """ + The deferrable state of the new transactions started on the connection. + """ + return self._deferrable + + @deferrable.setter + def deferrable(self, value: Optional[bool]) -> None: + self._set_deferrable(value) + + def _set_deferrable(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("deferrable") + self._deferrable = bool(value) + self._begin_statement = b"" + + def _check_intrans_gen(self, attribute: str) -> PQGen[None]: + # Raise an exception if we are in a transaction + status = self.pgconn.transaction_status + if status == IDLE and self._pipeline: + yield from self._pipeline._sync_gen() + status = self.pgconn.transaction_status + if status != IDLE: + if self._num_transactions: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection.transaction() context in progress" + ) + else: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection in transaction status " + f"{pq.TransactionStatus(status).name}" + ) + + @property + def info(self) -> ConnectionInfo: + """A `ConnectionInfo` attribute to inspect connection properties.""" + return ConnectionInfo(self.pgconn) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + self._adapters = AdaptersMap(postgres.adapters) + + return self._adapters + + @property + def connection(self) -> "BaseConnection[Row]": + # implement the AdaptContext protocol + return self + + def fileno(self) -> int: + """Return the file descriptor of the connection. + + This function allows to use the connection as file-like object in + functions waiting for readiness, such as the ones defined in the + `selectors` module. + """ + return self.pgconn.socket + + def cancel(self) -> None: + """Cancel the current operation on the connection.""" + # No-op if the connection is closed + # this allows to use the method as callback handler without caring + # about its life. + if self.closed: + return + + if self._tpc and self._tpc[1]: + raise e.ProgrammingError( + "cancel() cannot be used with a prepared two-phase transaction" + ) + + c = self.pgconn.get_cancel() + c.cancel() + + def add_notice_handler(self, callback: NoticeHandler) -> None: + """ + Register a callable to be invoked when a notice message is received. 
+ + :param callback: the callback to call upon message received. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.append(callback) + + def remove_notice_handler(self, callback: NoticeHandler) -> None: + """ + Unregister a notice message callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.remove(callback) + + @staticmethod + def _notice_handler( + wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" + ) -> None: + self = wself() + if not (self and self._notice_handlers): + return + + diag = e.Diagnostic(res, pgconn_encoding(self.pgconn)) + for cb in self._notice_handlers: + try: + cb(diag) + except Exception as ex: + logger.exception("error processing notice callback '%s': %s", cb, ex) + + def add_notify_handler(self, callback: NotifyHandler) -> None: + """ + Register a callable to be invoked whenever a notification is received. + + :param callback: the callback to call upon notification received. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.append(callback) + + def remove_notify_handler(self, callback: NotifyHandler) -> None: + """ + Unregister a notification callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.remove(callback) + + @staticmethod + def _notify_handler( + wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify + ) -> None: + self = wself() + if not (self and self._notify_handlers): + return + + enc = pgconn_encoding(self.pgconn) + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + for cb in self._notify_handlers: + cb(n) + + @property + def prepare_threshold(self) -> Optional[int]: + """ + Number of times a query is executed before it is prepared. + + - If it is set to 0, every query is prepared the first time it is + executed. + - If it is set to `!None`, prepared statements are disabled on the + connection. + + Default value: 5 + """ + return self._prepared.prepare_threshold + + @prepare_threshold.setter + def prepare_threshold(self, value: Optional[int]) -> None: + self._prepared.prepare_threshold = value + + @property + def prepared_max(self) -> int: + """ + Maximum number of prepared statements on the connection. + + Default value: 100 + """ + return self._prepared.prepared_max + + @prepared_max.setter + def prepared_max(self, value: int) -> None: + self._prepared.prepared_max = value + + # Generators to perform high-level operations on the connection + # + # These operations are expressed in terms of non-blocking generators + # and the task of waiting when needed (when the generators yield) is left + # to the connections subclass, which might wait either in blocking mode + # or through asyncio. + # + # All these generators assume exclusive access to the connection: subclasses + # should have a lock and hold it before calling and consuming them. 
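A sketch of how the handlers and the prepared-statement settings described above may be used from application code; the callbacks and values are hypothetical:

    def log_notice(diag):
        # diag is a psycopg.errors.Diagnostic instance
        print(f"{diag.severity}: {diag.message_primary}")

    def on_notify(notify):
        # notify is a psycopg.Notify named tuple
        print(f"NOTIFY on {notify.channel}: {notify.payload} (pid {notify.pid})")

    conn.add_notice_handler(log_notice)
    conn.add_notify_handler(on_notify)

    conn.prepare_threshold = 0   # prepare every statement at first execution
    conn.prepared_max = 50       # keep at most 50 prepared statements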
+ + @classmethod + def _connect_gen( + cls: Type[ConnectionType], + conninfo: str = "", + *, + autocommit: bool = False, + ) -> PQGenConn[ConnectionType]: + """Generator to connect to the database and create a new instance.""" + pgconn = yield from connect(conninfo) + conn = cls(pgconn) + conn._autocommit = bool(autocommit) + return conn + + def _exec_command( + self, command: Query, result_format: pq.Format = TEXT + ) -> PQGen[Optional["PGresult"]]: + """ + Generator to send a command and receive the result to the backend. + + Only used to implement internal commands such as "commit", with eventual + arguments bound client-side. The cursor can do more complex stuff. + """ + self._check_connection_ok() + + if isinstance(command, str): + command = command.encode(pgconn_encoding(self.pgconn)) + elif isinstance(command, Composable): + command = command.as_bytes(self) + + if self._pipeline: + cmd = partial( + self.pgconn.send_query_params, + command, + None, + result_format=result_format, + ) + self._pipeline.command_queue.append(cmd) + self._pipeline.result_queue.append(None) + return None + + self.pgconn.send_query_params(command, None, result_format=result_format) + + result = (yield from execute(self.pgconn))[-1] + if result.status != COMMAND_OK and result.status != TUPLES_OK: + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + else: + raise e.InterfaceError( + f"unexpected result {pq.ExecStatus(result.status).name}" + f" from command {command.decode()!r}" + ) + return result + + def _check_connection_ok(self) -> None: + if self.pgconn.status == OK: + return + + if self.pgconn.status == BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + "cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + + def _start_query(self) -> PQGen[None]: + """Generator to start a transaction if necessary.""" + if self._autocommit: + return + + if self.pgconn.transaction_status != IDLE: + return + + yield from self._exec_command(self._get_tx_start_command()) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _get_tx_start_command(self) -> bytes: + if self._begin_statement: + return self._begin_statement + + parts = [b"BEGIN"] + + if self.isolation_level is not None: + val = IsolationLevel(self.isolation_level) + parts.append(b"ISOLATION LEVEL") + parts.append(val.name.replace("_", " ").encode()) + + if self.read_only is not None: + parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") + + if self.deferrable is not None: + parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") + + self._begin_statement = b" ".join(parts) + return self._begin_statement + + def _commit_gen(self) -> PQGen[None]: + """Generator implementing `Connection.commit()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "commit() cannot be used during a two-phase transaction" + ) + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"COMMIT") + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _rollback_gen(self) -> PQGen[None]: + """Generator implementing `Connection.rollback()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. 
(Either raise Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "rollback() cannot be used during a two-phase transaction" + ) + + # Get out of a "pipeline aborted" state + if self._pipeline: + yield from self._pipeline._sync_gen() + + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"ROLLBACK") + self._prepared.clear() + for cmd in self._prepared.get_maintenance_commands(): + yield from self._exec_command(cmd) + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: + """ + Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. + + The argument types and constraints are explained in + :ref:`two-phase-commit`. + + The values passed to the method will be available on the returned + object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. + """ + self._check_tpc() + return Xid.from_parts(format_id, gtrid, bqual) + + def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: + self._check_tpc() + + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self.pgconn.transaction_status != IDLE: + raise e.ProgrammingError( + "can't start two-phase transaction: connection in status" + f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" + ) + + if self._autocommit: + raise e.ProgrammingError( + "can't use two-phase transactions in autocommit mode" + ) + + self._tpc = (xid, False) + yield from self._exec_command(self._get_tx_start_command()) + + def _tpc_prepare_gen(self) -> PQGen[None]: + if not self._tpc: + raise e.ProgrammingError( + "'tpc_prepare()' must be called inside a two-phase transaction" + ) + if self._tpc[1]: + raise e.ProgrammingError( + "'tpc_prepare()' cannot be used during a prepared two-phase transaction" + ) + xid = self._tpc[0] + self._tpc = (xid, True) + yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _tpc_finish_gen( + self, action: LiteralString, xid: Union[Xid, str, None] + ) -> PQGen[None]: + fname = f"tpc_{action.lower()}()" + if xid is None: + if not self._tpc: + raise e.ProgrammingError( + f"{fname} without xid must must be" + " called inside a two-phase transaction" + ) + xid = self._tpc[0] + else: + if self._tpc: + raise e.ProgrammingError( + f"{fname} with xid must must be called" + " outside a two-phase transaction" + ) + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self._tpc and not self._tpc[1]: + meth: Callable[[], PQGen[None]] + meth = getattr(self, f"_{action.lower()}_gen") + self._tpc = None + yield from meth() + else: + yield from self._exec_command( + SQL("{} PREPARED {}").format(SQL(action), str(xid)) + ) + self._tpc = None + + def _check_tpc(self) -> None: + """Raise NotSupportedError if TPC is not supported.""" + # TPC supported on every supported PostgreSQL version. + pass + + +class Connection(BaseConnection[Row]): + """ + Wrapper for a connection to the database. 
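The `xid()` helper and the `_tpc_*` generators above back the DBAPI two-phase commit methods exposed by the concrete classes. A minimal sketch of the expected call sequence; the table and identifiers are hypothetical:

    xid = conn.xid(format_id=1, gtrid="batch-42", bqual="node-1")
    conn.tpc_begin(xid)
    conn.execute("INSERT INTO mytable (id) VALUES (%s)", (1,))
    conn.tpc_prepare()            # issues PREPARE TRANSACTION
    conn.tpc_commit()             # or conn.tpc_rollback()

    # Recover transactions prepared by other sessions:
    for pending in conn.tpc_recover():
        conn.tpc_rollback(pending)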
+ """ + + __module__ = "psycopg" + + cursor_factory: Type[Cursor[Row]] + server_cursor_factory: Type[ServerCursor[Row]] + row_factory: RowFactory[Row] + _pipeline: Optional[Pipeline] + _Self = TypeVar("_Self", bound="Connection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = threading.Lock() + self.cursor_factory = Cursor + self.server_cursor_factory = ServerCursor + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[Cursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "Connection[TupleRow]": + ... + + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: Optional[RowFactory[Row]] = None, + cursor_factory: Optional[Type[Cursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Any, + ) -> "Connection[Any]": + """ + Connect to a database server and return a new `Connection` instance. + """ + params = cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + self.close() + + @classmethod + def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + :param conninfo: Connection string as received by `~Connection.connect()`. + :param kwargs: Overriding connection arguments as received by `!connect()`. + :return: Connection arguments merged and eventually modified, in a + format similar to `~conninfo.conninfo_to_dict()`. 
+ """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + return params + + def close(self) -> None: + """Close the database connection.""" + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> Cursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + ) -> Cursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: RowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> ServerCursor[CursorRow]: + ... + + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[RowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[Cursor[Any], ServerCursor[Any]]: + """ + Return a new cursor to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[Cursor[Any], ServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> Cursor[Row]: + """Execute a query and return a cursor to read its results.""" + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + with self.lock: + self.wait(self._commit_gen()) + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + with self.lock: + self.wait(self._rollback_gen()) + + @contextmanager + def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> Iterator[Transaction]: + """ + Start a context block with a new transaction or nested transaction. + + :param savepoint_name: Name of the savepoint used to manage a nested + transaction. If `!None`, one will be chosen automatically. + :param force_rollback: Roll back the transaction at the end of the + block even if there were no error (e.g. to try a no-op process). + :rtype: Transaction + """ + tx = Transaction(self, savepoint_name, force_rollback) + if self._pipeline: + with self.pipeline(), tx, self.pipeline(): + yield tx + else: + with tx: + yield tx + + def notifies(self) -> Generator[Notify, None, None]: + """ + Yield `Notify` objects as soon as they are received from the database. 
+ """ + while True: + with self.lock: + try: + ns = self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @contextmanager + def pipeline(self) -> Iterator[Pipeline]: + """Switch the connection into pipeline mode.""" + with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = Pipeline(self) + + try: + with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + """ + Consume a generator operating on the connection. + + The function must be used on generators that don't change connection + fd (i.e. not on connect and reset). + """ + try: + return waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except KeyboardInterrupt: + # On Ctrl-C, try to cancel the query in the server, otherwise + # the connection will remain stuck in ACTIVE state. + c = self.pgconn.get_cancel() + c.cancel() + try: + waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + """Consume a connection generator.""" + return waiting.wait_conn(gen, timeout=timeout) + + def _set_autocommit(self, value: bool) -> None: + with self.lock: + self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + with self.lock: + self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + with self.lock: + self.wait(self._set_deferrable_gen(value)) + + def tpc_begin(self, xid: Union[Xid, str]) -> None: + """ + Begin a TPC transaction with the given transaction ID `!xid`. + """ + with self.lock: + self.wait(self._tpc_begin_gen(xid)) + + def tpc_prepare(self) -> None: + """ + Perform the first phase of a transaction started with `tpc_begin()`. + """ + try: + with self.lock: + self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + """ + Commit a prepared two-phase transaction. + """ + with self.lock: + self.wait(self._tpc_finish_gen("COMMIT", xid)) + + def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + """ + Roll back a prepared two-phase transaction. 
+ """ + with self.lock: + self.wait(self._tpc_finish_gen("ROLLBACK", xid)) + + def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + cur.execute(Xid._get_recover_query()) + res = cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + self.rollback() + + return res diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py new file mode 100644 index 0000000..aa02dc0 --- /dev/null +++ b/psycopg/psycopg/connection_async.py @@ -0,0 +1,436 @@ +""" +psycopg async connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import asyncio +import logging +from types import TracebackType +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING +from contextlib import asynccontextmanager + +from . import pq +from . import errors as e +from . import waiting +from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from ._tpc import Xid +from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async +from ._pipeline import AsyncPipeline +from ._encodings import pgconn_encoding +from .connection import BaseConnection, CursorRow, Notify +from .generators import notifies +from .transaction import AsyncTransaction +from .cursor_async import AsyncCursor +from .server_cursor import AsyncServerCursor + +if TYPE_CHECKING: + from .pq.abc import PGconn + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class AsyncConnection(BaseConnection[Row]): + """ + Asynchronous wrapper for a connection to the database. + """ + + __module__ = "psycopg" + + cursor_factory: Type[AsyncCursor[Row]] + server_cursor_factory: Type[AsyncServerCursor[Row]] + row_factory: AsyncRowFactory[Row] + _pipeline: Optional[AsyncPipeline] + _Self = TypeVar("_Self", bound="AsyncConnection[Any]") + + def __init__( + self, + pgconn: "PGconn", + row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row), + ): + super().__init__(pgconn) + self.row_factory = row_factory + self.lock = asyncio.Lock() + self.cursor_factory = AsyncCursor + self.server_cursor_factory = AsyncServerCursor + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[Row]": + # TODO: returned type should be _Self. See #308. + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: Optional[Type[AsyncCursor[Any]]] = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[TupleRow]": + ... 
+ + @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + context: Optional[AdaptContext] = None, + row_factory: Optional[AsyncRowFactory[Row]] = None, + cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, + **kwargs: Any, + ) -> "AsyncConnection[Any]": + + if sys.platform == "win32": + loop = asyncio.get_running_loop() + if isinstance(loop, asyncio.ProactorEventLoop): + raise e.InterfaceError( + "Psycopg cannot use the 'ProactorEventLoop' to run in async" + " mode. Please use a compatible event loop, for instance by" + " setting 'asyncio.set_event_loop_policy" + "(WindowsSelectorEventLoopPolicy())'" + ) + + params = await cls._get_connection_params(conninfo, **kwargs) + conninfo = make_conninfo(**params) + + try: + rv = await cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit), + timeout=params["connect_timeout"], + ) + except e.Error as ex: + raise ex.with_traceback(None) + + if row_factory: + rv.row_factory = row_factory + if cursor_factory: + rv.cursor_factory = cursor_factory + if context: + rv._adapters = AdaptersMap(context.adapters) + rv.prepare_threshold = prepare_threshold + return rv + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.closed: + return + + if exc_type: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + await self.rollback() + except Exception as exc2: + logger.warning( + "error ignored in rollback on %s: %s", + self, + exc2, + ) + else: + await self.commit() + + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + await self.close() + + @classmethod + async def _get_connection_params( + cls, conninfo: str, **kwargs: Any + ) -> Dict[str, Any]: + """Manipulate connection parameters before connecting. + + .. versionchanged:: 3.1 + Unlike the sync counterpart, perform non-blocking address + resolution and populate the ``hostaddr`` connection parameter, + unless the user has provided one themselves. See + `~psycopg._dns.resolve_hostaddr_async()` for details. + + """ + params = conninfo_to_dict(conninfo, **kwargs) + + # Make sure there is an usable connect_timeout + if "connect_timeout" in params: + params["connect_timeout"] = int(params["connect_timeout"]) + else: + params["connect_timeout"] = None + + # Resolve host addresses in non-blocking way + params = await resolve_hostaddr_async(params) + + return params + + async def close(self) -> None: + if self.closed: + return + self._closed = True + self.pgconn.finish() + + @overload + def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: + ... + + @overload + def cursor( + self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow] + ) -> AsyncCursor[CursorRow]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[Row]: + ... + + @overload + def cursor( + self, + name: str, + *, + binary: bool = False, + row_factory: AsyncRowFactory[CursorRow], + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> AsyncServerCursor[CursorRow]: + ... 
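Because the transaction attributes are read-only on async connections (see `_no_set_async()` further below), they are changed with the ``set_*()`` coroutines instead. A sketch assuming an open ``aconn`` and a hypothetical ``mytable``:

    async def work(aconn):
        await aconn.set_autocommit(True)   # instead of `aconn.autocommit = True`
        async with aconn.cursor() as cur:
            await cur.execute("SELECT now()")
            print(await cur.fetchone())
        async with aconn.transaction():
            await aconn.execute("INSERT INTO mytable (id) VALUES (%s)", (1,))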
+ + def cursor( + self, + name: str = "", + *, + binary: bool = False, + row_factory: Optional[AsyncRowFactory[Any]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]: + """ + Return a new `AsyncCursor` to send commands and queries to the connection. + """ + self._check_connection_ok() + + if not row_factory: + row_factory = self.row_factory + + cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]] + if name: + cur = self.server_cursor_factory( + self, + name=name, + row_factory=row_factory, + scrollable=scrollable, + withhold=withhold, + ) + else: + cur = self.cursor_factory(self, row_factory=row_factory) + + if binary: + cur.format = BINARY + + return cur + + async def execute( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: bool = False, + ) -> AsyncCursor[Row]: + try: + cur = self.cursor() + if binary: + cur.format = BINARY + + return await cur.execute(query, params, prepare=prepare) + + except e.Error as ex: + raise ex.with_traceback(None) + + async def commit(self) -> None: + async with self.lock: + await self.wait(self._commit_gen()) + + async def rollback(self) -> None: + async with self.lock: + await self.wait(self._rollback_gen()) + + @asynccontextmanager + async def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> AsyncIterator[AsyncTransaction]: + """ + Start a context block with a new transaction or nested transaction. + + :rtype: AsyncTransaction + """ + tx = AsyncTransaction(self, savepoint_name, force_rollback) + if self._pipeline: + async with self.pipeline(), tx, self.pipeline(): + yield tx + else: + async with tx: + yield tx + + async def notifies(self) -> AsyncGenerator[Notify, None]: + while True: + async with self.lock: + try: + ns = await self.wait(notifies(self.pgconn)) + except e.Error as ex: + raise ex.with_traceback(None) + enc = pgconn_encoding(self.pgconn) + for pgn in ns: + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + yield n + + @asynccontextmanager + async def pipeline(self) -> AsyncIterator[AsyncPipeline]: + """Context manager to switch the connection into pipeline mode.""" + async with self.lock: + self._check_connection_ok() + + pipeline = self._pipeline + if pipeline is None: + # WARNING: reference loop, broken ahead. + pipeline = self._pipeline = AsyncPipeline(self) + + try: + async with pipeline: + yield pipeline + finally: + if pipeline.level == 0: + async with self.lock: + assert pipeline is self._pipeline + self._pipeline = None + + async def wait(self, gen: PQGen[RV]) -> RV: + try: + return await waiting.wait_async(gen, self.pgconn.socket) + except KeyboardInterrupt: + # TODO: this doesn't seem to work as it does for sync connections + # see tests/test_concurrency_async.py::test_ctrl_c + # In the test, the code doesn't reach this branch. 
+ + # On Ctrl-C, try to cancel the query in the server, otherwise + # otherwise the connection will be stuck in ACTIVE state + c = self.pgconn.get_cancel() + c.cancel() + try: + await waiting.wait_async(gen, self.pgconn.socket) + except e.QueryCanceled: + pass # as expected + raise + + @classmethod + async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + return await waiting.wait_conn_async(gen, timeout) + + def _set_autocommit(self, value: bool) -> None: + self._no_set_async("autocommit") + + async def set_autocommit(self, value: bool) -> None: + """Async version of the `~Connection.autocommit` setter.""" + async with self.lock: + await self.wait(self._set_autocommit_gen(value)) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._no_set_async("isolation_level") + + async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + """Async version of the `~Connection.isolation_level` setter.""" + async with self.lock: + await self.wait(self._set_isolation_level_gen(value)) + + def _set_read_only(self, value: Optional[bool]) -> None: + self._no_set_async("read_only") + + async def set_read_only(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.read_only` setter.""" + async with self.lock: + await self.wait(self._set_read_only_gen(value)) + + def _set_deferrable(self, value: Optional[bool]) -> None: + self._no_set_async("deferrable") + + async def set_deferrable(self, value: Optional[bool]) -> None: + """Async version of the `~Connection.deferrable` setter.""" + async with self.lock: + await self.wait(self._set_deferrable_gen(value)) + + def _no_set_async(self, attribute: str) -> None: + raise AttributeError( + f"'the {attribute!r} property is read-only on async connections:" + f" please use 'await .set_{attribute}()' instead." + ) + + async def tpc_begin(self, xid: Union[Xid, str]) -> None: + async with self.lock: + await self.wait(self._tpc_begin_gen(xid)) + + async def tpc_prepare(self) -> None: + try: + async with self.lock: + await self.wait(self._tpc_prepare_gen()) + except e.ObjectNotInPrerequisiteState as ex: + raise e.NotSupportedError(str(ex)) from None + + async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("commit", xid)) + + async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None: + async with self.lock: + await self.wait(self._tpc_finish_gen("rollback", xid)) + + async def tpc_recover(self) -> List[Xid]: + self._check_tpc() + status = self.info.transaction_status + async with self.cursor(row_factory=args_row(Xid._from_record)) as cur: + await cur.execute(Xid._get_recover_query()) + res = await cur.fetchall() + + if status == IDLE and self.info.transaction_status == INTRANS: + await self.rollback() + + return res diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py new file mode 100644 index 0000000..3b21f83 --- /dev/null +++ b/psycopg/psycopg/conninfo.py @@ -0,0 +1,378 @@ +""" +Functions to manipulate conninfo strings +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import re +import socket +import asyncio +from typing import Any, Dict, List, Optional +from pathlib import Path +from datetime import tzinfo +from functools import lru_cache +from ipaddress import ip_address + +from . import pq +from . 
import errors as e +from ._tz import get_tzinfo +from ._encodings import pgconn_encoding + + +def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: + """ + Merge a string and keyword params into a single conninfo string. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: A connection string valid for PostgreSQL, with the `!kwargs` + parameters merged. + + Raise `~psycopg.ProgrammingError` if the input doesn't make a valid + conninfo string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + if not conninfo and not kwargs: + return "" + + # If no kwarg specified don't mung the conninfo but check if it's correct. + # Make sure to return a string, not a subtype, to avoid making Liskov sad. + if not kwargs: + _parse_conninfo(conninfo) + return str(conninfo) + + # Override the conninfo with the parameters + # Drop the None arguments + kwargs = {k: v for (k, v) in kwargs.items() if v is not None} + + if conninfo: + tmp = conninfo_to_dict(conninfo) + tmp.update(kwargs) + kwargs = tmp + + conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()) + + # Verify the result is valid + _parse_conninfo(conninfo) + + return conninfo + + +def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: + """ + Convert the `!conninfo` string into a dictionary of parameters. + + :param conninfo: A `connection string`__ as accepted by PostgreSQL. + :param kwargs: Parameters overriding the ones specified in `!conninfo`. + :return: Dictionary with the parameters parsed from `!conninfo` and + `!kwargs`. + + Raise `~psycopg.ProgrammingError` if `!conninfo` is not a a valid connection + string. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING + """ + opts = _parse_conninfo(conninfo) + rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} + for k, v in kwargs.items(): + if v is not None: + rv[k] = v + return rv + + +def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: + """ + Verify that `!conninfo` is a valid connection string. + + Raise ProgrammingError if the string is not valid. + + Return the result of pq.Conninfo.parse() on success. + """ + try: + return pq.Conninfo.parse(conninfo.encode()) + except e.OperationalError as ex: + raise e.ProgrammingError(str(ex)) + + +re_escape = re.compile(r"([\\'])") +re_space = re.compile(r"\s") + + +def _param_escape(s: str) -> str: + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: + return "''" + + s = re_escape.sub(r"\\\1", s) + if re_space.search(s): + s = "'" + s + "'" + + return s + + +class ConnectionInfo: + """Allow access to information about the connection.""" + + __module__ = "psycopg" + + def __init__(self, pgconn: pq.abc.PGconn): + self.pgconn = pgconn + + @property + def vendor(self) -> str: + """A string representing the database vendor connected to.""" + return "PostgreSQL" + + @property + def host(self) -> str: + """The server host name of the active connection. See :pq:`PQhost()`.""" + return self._get_pgconn_attr("host") + + @property + def hostaddr(self) -> str: + """The server IP address of the connection. See :pq:`PQhostaddr()`.""" + return self._get_pgconn_attr("hostaddr") + + @property + def port(self) -> int: + """The port of the active connection. 
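A sketch of the two helpers above; the exact ordering of the merged parameters is an implementation detail:

    from psycopg.conninfo import make_conninfo, conninfo_to_dict

    dsn = make_conninfo("dbname=test user=app", connect_timeout=10)
    print(dsn)                      # e.g. "dbname=test user=app connect_timeout=10"
    print(conninfo_to_dict(dsn))    # {'dbname': 'test', 'user': 'app', 'connect_timeout': '10'}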
See :pq:`PQport()`.""" + return int(self._get_pgconn_attr("port")) + + @property + def dbname(self) -> str: + """The database name of the connection. See :pq:`PQdb()`.""" + return self._get_pgconn_attr("db") + + @property + def user(self) -> str: + """The user name of the connection. See :pq:`PQuser()`.""" + return self._get_pgconn_attr("user") + + @property + def password(self) -> str: + """The password of the connection. See :pq:`PQpass()`.""" + return self._get_pgconn_attr("password") + + @property + def options(self) -> str: + """ + The command-line options passed in the connection request. + See :pq:`PQoptions`. + """ + return self._get_pgconn_attr("options") + + def get_parameters(self) -> Dict[str, str]: + """Return the connection parameters values. + + Return all the parameters set to a non-default value, which might come + either from the connection string and parameters passed to + `~Connection.connect()` or from environment variables. The password + is never returned (you can read it using the `password` attribute). + """ + pyenc = self.encoding + + # Get the known defaults to avoid reporting them + defaults = { + i.keyword: i.compiled + for i in pq.Conninfo.get_defaults() + if i.compiled is not None + } + # Not returned by the libq. Bug? Bet we're using SSH. + defaults.setdefault(b"channel_binding", b"prefer") + defaults[b"passfile"] = str(Path.home() / ".pgpass").encode() + + return { + i.keyword.decode(pyenc): i.val.decode(pyenc) + for i in self.pgconn.info + if i.val is not None + and i.keyword != b"password" + and i.val != defaults.get(i.keyword) + } + + @property + def dsn(self) -> str: + """Return the connection string to connect to the database. + + The string contains all the parameters set to a non-default value, + which might come either from the connection string and parameters + passed to `~Connection.connect()` or from environment variables. The + password is never returned (you can read it using the `password` + attribute). + """ + return make_conninfo(**self.get_parameters()) + + @property + def status(self) -> pq.ConnStatus: + """The status of the connection. See :pq:`PQstatus()`.""" + return pq.ConnStatus(self.pgconn.status) + + @property + def transaction_status(self) -> pq.TransactionStatus: + """ + The current in-transaction status of the session. + See :pq:`PQtransactionStatus()`. + """ + return pq.TransactionStatus(self.pgconn.transaction_status) + + @property + def pipeline_status(self) -> pq.PipelineStatus: + """ + The current pipeline status of the client. + See :pq:`PQpipelineStatus()`. + """ + return pq.PipelineStatus(self.pgconn.pipeline_status) + + def parameter_status(self, param_name: str) -> Optional[str]: + """ + Return a parameter setting of the connection. + + Return `None` is the parameter is unknown. + """ + res = self.pgconn.parameter_status(param_name.encode(self.encoding)) + return res.decode(self.encoding) if res is not None else None + + @property + def server_version(self) -> int: + """ + An integer representing the server version. See :pq:`PQserverVersion()`. + """ + return self.pgconn.server_version + + @property + def backend_pid(self) -> int: + """ + The process ID (PID) of the backend process handling this connection. + See :pq:`PQbackendPID()`. + """ + return self.pgconn.backend_pid + + @property + def error_message(self) -> str: + """ + The error message most recently generated by an operation on the connection. + See :pq:`PQerrorMessage()`. 
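A sketch of inspecting an open connection through the `ConnectionInfo` attributes above; the values shown are illustrative:

    info = conn.info
    print(info.host, info.port, info.dbname, info.user)
    print(info.server_version)                  # e.g. 150002 for PostgreSQL 15.2
    print(info.transaction_status)              # e.g. TransactionStatus.IDLE
    print(info.parameter_status("TimeZone"))    # e.g. 'UTC'
    print(info.get_parameters())                # non-default params, password excluded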
+ """ + return self._get_pgconn_attr("error_message") + + @property + def timezone(self) -> tzinfo: + """The Python timezone info of the connection's timezone.""" + return get_tzinfo(self.pgconn) + + @property + def encoding(self) -> str: + """The Python codec name of the connection's client encoding.""" + return pgconn_encoding(self.pgconn) + + def _get_pgconn_attr(self, name: str) -> str: + value: bytes = getattr(self.pgconn, name) + return value.decode(self.encoding) + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostname``, ``port`` to allow + connecting without further DNS lookups, eventually removing hosts that are + not resolved, keeping the lists of hosts and ports consistent. + + Raise `~psycopg.OperationalError` if connection is not possible (e.g. no + host resolve, inconsistent lists length). + """ + hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) + if hostaddr_arg: + # Already resolved + return params + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + if not host_arg: + # Nothing to resolve + return params + + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") if port_arg else [] + default_port = "5432" + + if len(ports_in) == 1: + # If only one port is specified, the libpq will apply it to all + # the hosts, so don't mangle it. + default_port = ports_in.pop() + + elif len(ports_in) > 1: + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. 
+ raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + ports_out = [] + + hosts_out = [] + hostaddr_out = [] + loop = asyncio.get_running_loop() + for i, host in enumerate(hosts_in): + if not host or host.startswith("/") or host[1:2] == ":": + # Local path + hosts_out.append(host) + hostaddr_out.append("") + if ports_in: + ports_out.append(ports_in[i]) + continue + + # If the host is already an ip address don't try to resolve it + if is_ip_address(host): + hosts_out.append(host) + hostaddr_out.append(host) + if ports_in: + ports_out.append(ports_in[i]) + continue + + try: + port = ports_in[i] if ports_in else default_port + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + except OSError as ex: + last_exc = ex + else: + for item in ans: + hosts_out.append(host) + hostaddr_out.append(item[4][0]) + if ports_in: + ports_out.append(ports_in[i]) + + # Throw an exception if no host could be resolved + if not hosts_out: + raise e.OperationalError(str(last_exc)) + + out = params.copy() + out["host"] = ",".join(hosts_out) + out["hostaddr"] = ",".join(hostaddr_out) + if ports_in: + out["port"] = ",".join(ports_out) + + return out + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py new file mode 100644 index 0000000..7514306 --- /dev/null +++ b/psycopg/psycopg/copy.py @@ -0,0 +1,904 @@ +""" +psycopg copy support +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import queue +import struct +import asyncio +import threading +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO +from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING + +from . import pq +from . import adapt +from . import errors as e +from .abc import Buffer, ConnectionType, PQGen, Transformer +from ._compat import create_task +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding +from .generators import copy_from, copy_to, copy_end + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + +PY_TEXT = adapt.PyFormat.TEXT +PY_BINARY = adapt.PyFormat.BINARY + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COPY_IN = pq.ExecStatus.COPY_IN +COPY_OUT = pq.ExecStatus.COPY_OUT + +ACTIVE = pq.TransactionStatus.ACTIVE + +# Size of data to accumulate before sending it down the network. We fill a +# buffer this size field by field, and when it passes the threshold size +# we ship it, so it may end up being bigger than this. +BUFFER_SIZE = 32 * 1024 + +# Maximum data size we want to queue to send to the libpq copy. Sending a +# buffer too big to be handled can cause an infinite loop in the libpq +# (#255) so we want to split it in more digestable chunks. +MAX_BUFFER_SIZE = 4 * BUFFER_SIZE +# Note: making this buffer too large, e.g. +# MAX_BUFFER_SIZE = 1024 * 1024 +# makes operations *way* slower! Probably triggering some quadraticity +# in the libpq memory management and data sending. + +# Max size of the write queue of buffers. More than that copy will block +# Each buffer should be around BUFFER_SIZE size. 
+QUEUE_SIZE = 1024 + + +class BaseCopy(Generic[ConnectionType]): + """ + Base implementation for the copy user interface. + + Two subclasses expose real methods with the sync/async differences. + + The difference between the text and binary format is managed by two + different `Formatter` subclasses. + + Writing (the I/O part) is implemented in the subclasses by a `Writer` or + `AsyncWriter` instance. Normally writing implies sending copy data to a + database, but a different writer might be chosen, e.g. to stream data into + a file for later use. + """ + + _Self = TypeVar("_Self", bound="BaseCopy[Any]") + + formatter: "Formatter" + + def __init__( + self, + cursor: "BaseCursor[ConnectionType, Any]", + *, + binary: Optional[bool] = None, + ): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + result = cursor.pgresult + if result: + self._direction = result.status + if self._direction != COPY_IN and self._direction != COPY_OUT: + raise e.ProgrammingError( + "the cursor should have performed a COPY operation;" + f" its status is {pq.ExecStatus(self._direction).name} instead" + ) + else: + self._direction = COPY_IN + + if binary is None: + binary = bool(result and result.binary_tuples) + + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) + if binary: + self.formatter = BinaryFormatter(tx) + else: + self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) + + self._finished = False + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def _enter(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + + def set_types(self, types: Sequence[Union[int, str]]) -> None: + """ + Set the types expected in a COPY operation. + + The types must be specified as a sequence of oid or PostgreSQL type + names (e.g. ``int4``, ``timestamptz[]``). + + This operation overcomes the lack of metadata returned by PostgreSQL + when a COPY operation begins: + + - On :sql:`COPY TO`, `!set_types()` allows to specify what types the + operation returns. If `!set_types()` is not used, the data will be + returned as unparsed strings or bytes instead of Python objects. + + - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the + database expects. This is especially useful in binary copy, because + PostgreSQL will apply no cast rule. + + """ + registry = self.cursor.adapters.types + oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] + + if self._direction == COPY_IN: + self.formatter.transformer.set_dumper_types(oids, self.formatter.format) + else: + self.formatter.transformer.set_loader_types(oids, self.formatter.format) + + # High level copy protocol generators (state change of the Copy object) + + def _read_gen(self) -> PQGen[Buffer]: + if self._finished: + return memoryview(b"") + + res = yield from copy_from(self._pgconn) + if isinstance(res, memoryview): + return res + + # res is the final PGresult + self._finished = True + + # This result is a COMMAND_OK which has info about the number of rows + # returned, but not about the columns, which is instead an information + # that was received on the COPY_OUT result at the beginning of COPY. + # So, don't replace the results in the cursor, just update the rowcount. 
+ nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return memoryview(b"") + + def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: + data = yield from self._read_gen() + if not data: + return None + + row = self.formatter.parse_row(data) + if row is None: + # Get the final result to finish the copy operation + yield from self._read_gen() + self._finished = True + return None + + return row + + def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + if not exc: + return + + if self._pgconn.transaction_status != ACTIVE: + # The server has already finished to send copy data. The connection + # is already in a good state. + return + + # Throw a cancel to the server, then consume the rest of the copy data + # (which might or might not have been already transferred entirely to + # the client, so we won't necessary see the exception associated with + # canceling). + self.connection.cancel() + try: + while (yield from self._read_gen()): + pass + except e.QueryCanceled: + pass + + +class Copy(BaseCopy["Connection[Any]"]): + """Manage a :sql:`COPY` operation. + + :param cursor: the cursor where the operation is performed. + :param binary: if `!True`, write binary format. + :param writer: the object to write to destination. If not specified, write + to the `!cursor` connection. + + Choosing `!binary` is not necessary if the cursor has executed a + :sql:`COPY` operation, because the operation result describes the format + too. The parameter is useful when a `!Copy` object is created manually and + no operation is performed on the cursor, such as when using ``writer=``\\ + `~psycopg.copy.FileWriter`. + + """ + + __module__ = "psycopg" + + writer: "Writer" + + def __init__( + self, + cursor: "Cursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["Writer"] = None, + ): + super().__init__(cursor, binary=binary) + if not writer: + writer = LibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.finish(exc_val) + + # End user sync interface + + def __iter__(self) -> Iterator[Buffer]: + """Implement block-by-block iteration on :sql:`COPY TO`.""" + while True: + data = self.read() + if not data: + break + yield data + + def read(self) -> Buffer: + """ + Read an unparsed row after a :sql:`COPY TO` operation. + + Return an empty string when the data is finished. + """ + return self.connection.wait(self._read_gen()) + + def rows(self) -> Iterator[Tuple[Any, ...]]: + """ + Iterate on the result of a :sql:`COPY TO` operation record by record. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + while True: + record = self.read_row() + if record is None: + break + yield record + + def read_row(self) -> Optional[Tuple[Any, ...]]: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return self.connection.wait(self._read_row_gen()) + + def write(self, buffer: Union[Buffer, str]) -> None: + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. 
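A sketch of loading data with the `write_row()` method described above, using the cursor's ``copy()`` context manager; ``mytable`` is hypothetical:

    records = [(1, "foo"), (2, "bar")]
    with conn.cursor() as cur:
        with cur.copy("COPY mytable (id, name) FROM STDIN") as copy:
            for record in records:
                copy.write_row(record)
    conn.commit()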
+ + If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. + """ + data = self.formatter.write(buffer) + if data: + self._write(data) + + def write_row(self, row: Sequence[Any]) -> None: + """Write a record to a table after a :sql:`COPY FROM` operation.""" + data = self.formatter.write_row(row) + if data: + self._write(data) + + def finish(self, exc: Optional[BaseException]) -> None: + """Terminate the copy operation and free the resources allocated. + + You shouldn't need to call this function yourself: it is usually called + by exit. It is available if, despite what is documented, you end up + using the `Copy` object outside a block. + """ + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + self._write(data) + self.writer.finish(exc) + self._finished = True + else: + self.connection.wait(self._end_copy_out_gen(exc)) + + +class Writer(ABC): + """ + A class to write copy data somewhere. + """ + + @abstractmethod + def write(self, data: Buffer) -> None: + """ + Write some data to destination. + """ + ... + + def finish(self, exc: Optional[BaseException] = None) -> None: + """ + Called when write operations are finished. + + If operations finished with an error, it will be passed to ``exc``. + """ + pass + + +class LibpqWriter(Writer): + """ + A `Writer` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "Cursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class QueuedLibpqDriver(LibpqWriter): + """ + A writer using a buffer to queue data to write to a Postgres database. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "Cursor[Any]"): + super().__init__(cursor) + + self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[threading.Thread] = None + self._worker_error: Optional[BaseException] = None + + def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value, or in case + of error. + + The function is designed to be run in a separate thread. + """ + try: + while True: + data = self._queue.get(block=True, timeout=24 * 60 * 60) + if not data: + break + self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. 
+ self._worker_error = ex + + def write(self, data: Buffer) -> None: + if not self._worker: + # warning: reference loop, broken by _write_end + self._worker = threading.Thread(target=self.worker) + self._worker.daemon = True + self._worker.start() + + # If the worker thread raies an exception, re-raise it to the caller. + if self._worker_error: + raise self._worker_error + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + def finish(self, exc: Optional[BaseException] = None) -> None: + self._queue.put(b"") + + if self._worker: + self._worker.join() + self._worker = None # break the loop + + # Check if the worker thread raised any exception before terminating. + if self._worker_error: + raise self._worker_error + + super().finish(exc) + + +class FileWriter(Writer): + """ + A `Writer` to write copy data to a file-like object. + + :param file: the file where to write copy data. It must be open for writing + in binary mode. + """ + + def __init__(self, file: IO[bytes]): + self.file = file + + def write(self, data: Buffer) -> None: + self.file.write(data) # type: ignore[arg-type] + + +class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): + """Manage an asynchronous :sql:`COPY` operation.""" + + __module__ = "psycopg" + + writer: "AsyncWriter" + + def __init__( + self, + cursor: "AsyncCursor[Any]", + *, + binary: Optional[bool] = None, + writer: Optional["AsyncWriter"] = None, + ): + super().__init__(cursor, binary=binary) + + if not writer: + writer = AsyncLibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.finish(exc_val) + + async def __aiter__(self) -> AsyncIterator[Buffer]: + while True: + data = await self.read() + if not data: + break + yield data + + async def read(self) -> Buffer: + return await self.connection.wait(self._read_gen()) + + async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: + while True: + record = await self.read_row() + if record is None: + break + yield record + + async def read_row(self) -> Optional[Tuple[Any, ...]]: + return await self.connection.wait(self._read_row_gen()) + + async def write(self, buffer: Union[Buffer, str]) -> None: + data = self.formatter.write(buffer) + if data: + await self._write(data) + + async def write_row(self, row: Sequence[Any]) -> None: + data = self.formatter.write_row(row) + if data: + await self._write(data) + + async def finish(self, exc: Optional[BaseException]) -> None: + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + await self._write(data) + await self.writer.finish(exc) + self._finished = True + else: + await self.connection.wait(self._end_copy_out_gen(exc)) + + +class AsyncWriter(ABC): + """ + A class to write copy data somewhere (for async connections). + """ + + @abstractmethod + async def write(self, data: Buffer) -> None: + ... 
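AsyncCopy mirrors the synchronous interface with awaitable methods and `async with`/`async for`. A usage sketch, assuming the same hypothetical `items(id int4, name text)` table and a conninfo taken from the environment:

    import asyncio
    import psycopg

    async def load_items(rows):
        async with await psycopg.AsyncConnection.connect("") as conn:
            async with conn.cursor() as cur:
                async with cur.copy("COPY items (id, name) FROM STDIN") as copy:
                    for row in rows:
                        await copy.write_row(row)

    asyncio.run(load_items([(1, "apple"), (2, "pear")]))
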
+ + async def finish(self, exc: Optional[BaseException] = None) -> None: + pass + + +class AsyncLibpqWriter(AsyncWriter): + """ + An `AsyncWriter` to write copy data to a Postgres database. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + async def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + bmsg: Optional[bytes] + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + self.cursor._results = [res] + + +class AsyncQueuedLibpqWriter(AsyncLibpqWriter): + """ + An `AsyncWriter` using a buffer to queue data to write. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + def __init__(self, cursor: "AsyncCursor[Any]"): + super().__init__(cursor) + + self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE) + self._worker: Optional[asyncio.Future[None]] = None + + async def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value. + + The function is designed to be run in a separate task. + """ + while True: + data = await self._queue.get() + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) + + async def write(self, data: Buffer) -> None: + if not self._worker: + self._worker = create_task(self.worker()) + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + async def finish(self, exc: Optional[BaseException] = None) -> None: + await self._queue.put(b"") + + if self._worker: + await asyncio.gather(self._worker) + self._worker = None # break reference loops if any + + await super().finish(exc) + + +class Formatter(ABC): + """ + A class which understand a copy format (text, binary). + """ + + format: pq.Format + + def __init__(self, transformer: Transformer): + self.transformer = transformer + self._write_buffer = bytearray() + self._row_mode = False # true if the user is using write_row() + + @abstractmethod + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + ... + + @abstractmethod + def write(self, buffer: Union[Buffer, str]) -> Buffer: + ... + + @abstractmethod + def write_row(self, row: Sequence[Any]) -> Buffer: + ... + + @abstractmethod + def end(self) -> Buffer: + ... 
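The `writer=` parameter of `Copy`/`AsyncCopy` allows redirecting the generated COPY stream away from the connection; `FileWriter` above is the file-based case mentioned in the `Copy` docstring. A rough sketch, assuming a cursor `cur`, an illustrative file name, and a `Copy` object created manually (no COPY statement is run on the cursor):

    from psycopg.copy import Copy, FileWriter

    records = [(1, "apple"), (2, "pear")]
    with open("items.copy", "wb") as f:
        with Copy(cur, writer=FileWriter(f)) as copy:
            for record in records:
                copy.write_row(record)
    # items.copy now holds the rows in COPY text format for a later load
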
+ + +class TextFormatter(Formatter): + + format = TEXT + + def __init__(self, transformer: Transformer, encoding: str = "utf-8"): + super().__init__(transformer) + self._encoding = encoding + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if data: + return parse_row_text(data, self.transformer) + else: + return None + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + format_row_text(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + return data.encode(self._encoding) + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. + return data + + +class BinaryFormatter(Formatter): + + format = BINARY + + def __init__(self, transformer: Transformer): + super().__init__(transformer) + self._signature_sent = False + + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: + if not self._signature_sent: + if data[: len(_binary_signature)] != _binary_signature: + raise e.DataError( + "binary copy doesn't start with the expected signature" + ) + self._signature_sent = True + data = data[len(_binary_signature) :] + + elif data == _binary_trailer: + return None + + return parse_row_binary(data, self.transformer) + + def write(self, buffer: Union[Buffer, str]) -> Buffer: + data = self._ensure_bytes(buffer) + self._signature_sent = True + return data + + def write_row(self, row: Sequence[Any]) -> Buffer: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + if not self._signature_sent: + self._write_buffer += _binary_signature + self._signature_sent = True + + format_row_binary(row, self.transformer, self._write_buffer) + if len(self._write_buffer) > BUFFER_SIZE: + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + else: + return b"" + + def end(self) -> Buffer: + # If we have sent no data we need to send the signature + # and the trailer + if not self._signature_sent: + self._write_buffer += _binary_signature + self._write_buffer += _binary_trailer + + elif self._row_mode: + # if we have sent data already, we have sent the signature + # too (either with the first row, or we assume that in + # block mode the signature is included). + # Write the trailer only if we are sending rows (with the + # assumption that who is copying binary data is sending the + # whole format). + self._write_buffer += _binary_trailer + + buffer, self._write_buffer = self._write_buffer, bytearray() + return buffer + + def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: + if isinstance(data, str): + raise TypeError("cannot copy str data in binary mode: use bytes instead") + else: + # Assume, for simplicity, that the user is not passing stupid + # things to the write function. If that's the case, things + # will fail downstream. 
+ return data + + +def _format_row_text( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for copy.""" + if out is None: + out = bytearray() + + if not row: + out += b"\n" + return out + + for item in row: + if item is not None: + dumper = tx.get_dumper(item, PY_TEXT) + b = dumper.dump(item) + out += _dump_re.sub(_dump_sub, b) + else: + out += rb"\N" + out += b"\t" + + out[-1:] = b"\n" + return out + + +def _format_row_binary( + row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None +) -> bytearray: + """Convert a row of objects to the data to send for binary copy.""" + if out is None: + out = bytearray() + + out += _pack_int2(len(row)) + adapted = tx.dump_sequence(row, [PY_BINARY] * len(row)) + for b in adapted: + if b is not None: + out += _pack_int4(len(b)) + out += b + else: + out += _binary_null + + return out + + +def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + if not isinstance(data, bytes): + data = bytes(data) + fields = data.split(b"\t") + fields[-1] = fields[-1][:-1] # drop \n + row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] + return tx.load_sequence(row) + + +def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: + row: List[Optional[Buffer]] = [] + nfields = _unpack_int2(data, 0)[0] + pos = 2 + for i in range(nfields): + length = _unpack_int4(data, pos)[0] + pos += 4 + if length >= 0: + row.append(data[pos : pos + length]) + pos += length + else: + row.append(None) + + return tx.load_sequence(row) + + +_pack_int2 = struct.Struct("!h").pack +_pack_int4 = struct.Struct("!i").pack +_unpack_int2 = struct.Struct("!h").unpack_from +_unpack_int4 = struct.Struct("!i").unpack_from + +_binary_signature = ( + b"PGCOPY\n\xff\r\n\0" # Signature + b"\x00\x00\x00\x00" # flags + b"\x00\x00\x00\x00" # extra length +) +_binary_trailer = b"\xff\xff" +_binary_null = b"\xff\xff\xff\xff" + +_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") +_dump_repl = { + b"\b": b"\\b", + b"\t": b"\\t", + b"\n": b"\\n", + b"\v": b"\\v", + b"\f": b"\\f", + b"\r": b"\\r", + b"\\": b"\\\\", +} + + +def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes: + return __map[m.group(0)] + + +_load_re = re.compile(b"\\\\[btnvfr\\\\]") +_load_repl = {v: k for k, v in _dump_repl.items()} + + +def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes: + return __map[m.group(0)] + + +# Override functions with fast versions if available +if _psycopg: + format_row_text = _psycopg.format_row_text + format_row_binary = _psycopg.format_row_binary + parse_row_text = _psycopg.parse_row_text + parse_row_binary = _psycopg.parse_row_binary + +else: + format_row_text = _format_row_text + format_row_binary = _format_row_binary + parse_row_text = _parse_row_text + parse_row_binary = _parse_row_binary diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py new file mode 100644 index 0000000..323903a --- /dev/null +++ b/psycopg/psycopg/crdb/__init__.py @@ -0,0 +1,19 @@ +""" +CockroachDB support package. +""" + +# Copyright (C) 2022 The Psycopg Team + +from . 
import _types +from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo + +adapters = _types.adapters # exposed by the package +connect = CrdbConnection.connect + +_types.register_crdb_adapters(adapters) + +__all__ = [ + "AsyncCrdbConnection", + "CrdbConnection", + "CrdbConnectionInfo", +] diff --git a/psycopg/psycopg/crdb/_types.py b/psycopg/psycopg/crdb/_types.py new file mode 100644 index 0000000..5311e05 --- /dev/null +++ b/psycopg/psycopg/crdb/_types.py @@ -0,0 +1,163 @@ +""" +Types configuration specific for CockroachDB. +""" + +# Copyright (C) 2022 The Psycopg Team + +from enum import Enum +from .._typeinfo import TypeInfo, TypesRegistry + +from ..abc import AdaptContext, NoneType +from ..postgres import TEXT_OID +from .._adapters_map import AdaptersMap +from ..types.enum import EnumDumper, EnumBinaryDumper +from ..types.none import NoneDumper + +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + + +class CrdbEnumDumper(EnumDumper): + oid = TEXT_OID + + +class CrdbEnumBinaryDumper(EnumBinaryDumper): + oid = TEXT_OID + + +class CrdbNoneDumper(NoneDumper): + oid = TEXT_OID + + +def register_postgres_adapters(context: AdaptContext) -> None: + # Same adapters used by PostgreSQL, or a good starting point for customization + + from ..types import array, bool, composite, datetime + from ..types import numeric, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + numeric.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) + + +def register_crdb_adapters(context: AdaptContext) -> None: + from .. import dbapi20 + from ..types import array + + register_postgres_adapters(context) + + # String must come after enum to map text oid -> string dumper + register_crdb_enum_adapters(context) + register_crdb_string_adapters(context) + register_crdb_json_adapters(context) + register_crdb_net_adapters(context) + register_crdb_none_adapters(context) + + dbapi20.register_dbapi20_adapters(adapters) + + array.register_all_arrays(adapters) + + +def register_crdb_string_adapters(context: AdaptContext) -> None: + from ..types import string + + # Dump strings with text oid instead of unknown. + # Unlike PostgreSQL, CRDB seems able to cast text to most types. 
+ context.adapters.register_dumper(str, string.StrDumper) + context.adapters.register_dumper(str, string.StrBinaryDumper) + + +def register_crdb_enum_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper) + context.adapters.register_dumper(Enum, CrdbEnumDumper) + + +def register_crdb_json_adapters(context: AdaptContext) -> None: + from ..types import json + + adapters = context.adapters + + # CRDB doesn't have json/jsonb: both names map to the jsonb oid + adapters.register_dumper(json.Json, json.JsonbBinaryDumper) + adapters.register_dumper(json.Json, json.JsonbDumper) + + adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper) + adapters.register_dumper(json.Jsonb, json.JsonbDumper) + + adapters.register_loader("json", json.JsonLoader) + adapters.register_loader("jsonb", json.JsonbLoader) + adapters.register_loader("json", json.JsonBinaryLoader) + adapters.register_loader("jsonb", json.JsonbBinaryLoader) + + +def register_crdb_net_adapters(context: AdaptContext) -> None: + from ..types import net + + adapters = context.adapters + + adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper) + adapters.register_dumper(None, net.InetBinaryDumper) + adapters.register_loader("inet", net.InetLoader) + adapters.register_loader("inet", net.InetBinaryLoader) + + +def register_crdb_none_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, CrdbNoneDumper) + + +for t in [ + TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb. 
+ TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8 + TypeInfo('"char"', 18, 1002), # special case, not generated + # autogenerated: start + # Generated from CockroachDB 22.1.0 + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("record", 2249, 2287), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("unknown", 705, 0), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + # autogenerated: end +]: + types.add(t) diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py new file mode 100644 index 0000000..6e79ed1 --- /dev/null +++ b/psycopg/psycopg/crdb/connection.py @@ -0,0 +1,186 @@ +""" +CockroachDB-specific connections. +""" + +# Copyright (C) 2022 The Psycopg Team + +import re +from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING + +from .. import errors as e +from ..abc import AdaptContext +from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow +from ..conninfo import ConnectionInfo +from ..connection import Connection +from .._adapters_map import AdaptersMap +from ..connection_async import AsyncConnection +from ._types import adapters + +if TYPE_CHECKING: + from ..pq.abc import PGconn + from ..cursor import Cursor + from ..cursor_async import AsyncCursor + + +class _CrdbConnectionMixin: + + _adapters: Optional[AdaptersMap] + pgconn: "PGconn" + + @classmethod + def is_crdb( + cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"] + ) -> bool: + """ + Return `!True` if the server connected to `!conn` is CockroachDB. + """ + if isinstance(conn, (Connection, AsyncConnection)): + conn = conn.pgconn + + return bool(conn.parameter_status(b"crdb_version")) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + # By default, use CockroachDB adapters map + self._adapters = AdaptersMap(adapters) + + return self._adapters + + @property + def info(self) -> "CrdbConnectionInfo": + return CrdbConnectionInfo(self.pgconn) + + def _check_tpc(self) -> None: + if self.is_crdb(self.pgconn): + raise e.NotSupportedError("CockroachDB doesn't support prepared statements") + + +class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): + """ + Wrapper for a connection to a CockroachDB database. 
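The CockroachDB-specific connection classes are usually reached through psycopg.crdb.connect(), the alias for CrdbConnection.connect() set up in the package __init__. A small sketch, with the conninfo left to the libpq defaults:

    import psycopg.crdb

    with psycopg.crdb.connect("") as conn:
        print(psycopg.crdb.CrdbConnection.is_crdb(conn))   # True for a CockroachDB server
        print(conn.info.vendor)                            # "CockroachDB"
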
+ """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[Row]": + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[TupleRow]": + ... + + @classmethod + def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]": + """ + Connect to a database server and return a new `CrdbConnection` instance. + """ + return super().connect(conninfo, **kwargs) # type: ignore[return-value] + + +class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]): + """ + Wrapper for an async connection to a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[Row]": + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[TupleRow]": + ... + + @classmethod + async def connect( + cls, conninfo: str = "", **kwargs: Any + ) -> "AsyncCrdbConnection[Any]": + return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return] + + +class CrdbConnectionInfo(ConnectionInfo): + """ + `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database. + """ + + __module__ = "psycopg.crdb" + + @property + def vendor(self) -> str: + return "CockroachDB" + + @property + def server_version(self) -> int: + """ + Return the CockroachDB server version connected. + + Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210). 
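The encoding is major * 10000 + minor * 100 + patch, parsed out of the crdb_version parameter status by parse_crdb_version(). A sketch with an illustrative version string:

    from psycopg.crdb import CrdbConnectionInfo

    ver = CrdbConnectionInfo.parse_crdb_version(
        "CockroachDB CCL v22.1.6 (x86_64-pc-linux-gnu)"
    )
    assert ver == 220106   # 22 * 10000 + 1 * 100 + 6
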
+ """ + sver = self.parameter_status("crdb_version") + if not sver: + raise e.InternalError("'crdb_version' parameter status not set") + + ver = self.parse_crdb_version(sver) + if ver is None: + raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}") + + return ver + + @classmethod + def parse_crdb_version(self, sver: str) -> Optional[int]: + m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) + if not m: + return None + + return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3)) diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py new file mode 100644 index 0000000..42c3804 --- /dev/null +++ b/psycopg/psycopg/cursor.py @@ -0,0 +1,921 @@ +""" +psycopg cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from functools import partial +from types import TracebackType +from typing import Any, Generic, Iterable, Iterator, List +from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar +from typing import overload, TYPE_CHECKING +from contextlib import contextmanager + +from . import pq +from . import adapt +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .copy import Copy, Writer as CopyWriter +from .rows import Row, RowMaker, RowFactory +from ._column import Column +from ._queries import PostgresQuery, PostgresClientQuery +from ._pipeline import Pipeline +from ._encodings import pgconn_encoding +from ._preparing import Prepare +from .generators import execute, fetch, send + +if TYPE_CHECKING: + from .abc import Transformer + from .pq.abc import PGconn, PGresult + from .connection import Connection + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE +PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class BaseCursor(Generic[ConnectionType, Row]): + __slots__ = """ + _conn format _adapters arraysize _closed _results pgresult _pos + _iresult _rowcount _query _tx _last_query _row_factory _make_row + _pgconn _execmany_returning + __weakref__ + """.split() + + ExecStatus = pq.ExecStatus + + _tx: "Transformer" + _make_row: RowMaker[Row] + _pgconn: "PGconn" + + def __init__(self, connection: ConnectionType): + self._conn = connection + self.format = TEXT + self._pgconn = connection.pgconn + self._adapters = adapt.AdaptersMap(connection.adapters) + self.arraysize = 1 + self._closed = False + self._last_query: Optional[Query] = None + self._reset() + + def _reset(self, reset_query: bool = True) -> None: + self._results: List["PGresult"] = [] + self.pgresult: Optional["PGresult"] = None + self._pos = 0 + self._iresult = 0 + self._rowcount = -1 + self._query: Optional[PostgresQuery] + # None if executemany() not executing, True/False according to returning state + self._execmany_returning: Optional[bool] = None + if reset_query: + self._query = None + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self._pgconn) + if self._closed: + status = "closed" + elif self.pgresult: + status = pq.ExecStatus(self.pgresult.status).name + else: + status = "no result" + return f"<{cls} [{status}] {info} at 0x{id(self):x}>" + + @property + def connection(self) -> ConnectionType: + 
"""The connection this cursor is using.""" + return self._conn + + @property + def adapters(self) -> adapt.AdaptersMap: + return self._adapters + + @property + def closed(self) -> bool: + """`True` if the cursor is closed.""" + return self._closed + + @property + def description(self) -> Optional[List[Column]]: + """ + A list of `Column` objects describing the current resultset. + + `!None` if the current resultset didn't return tuples. + """ + res = self.pgresult + + # We return columns if we have nfields, but also if we don't but + # the query said we got tuples (mostly to handle the super useful + # query "SELECT ;" + if res and ( + res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE + ): + return [Column(self, i) for i in range(res.nfields)] + else: + return None + + @property + def rowcount(self) -> int: + """Number of records affected by the precedent operation.""" + return self._rowcount + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + tuples = self.pgresult and self.pgresult.status == TUPLES_OK + return self._pos if tuples else None + + def setinputsizes(self, sizes: Sequence[Any]) -> None: + # no-op + pass + + def setoutputsize(self, size: Any, column: Optional[int] = None) -> None: + # no-op + pass + + def nextset(self) -> Optional[bool]: + """ + Move to the result set of the next query executed through `executemany()` + or to the next result set if `execute()` returned more than one. + + Return `!True` if a new result is available, which will be the one + methods `!fetch*()` will operate on. + """ + if self._iresult < len(self._results) - 1: + self._select_current_result(self._iresult + 1) + return True + else: + return None + + @property + def statusmessage(self) -> Optional[str]: + """ + The command status tag from the last SQL command executed. + + `!None` if the cursor doesn't have a result available. + """ + msg = self.pgresult.command_status if self.pgresult else None + return msg.decode() if msg else None + + def _make_row_maker(self) -> RowMaker[Row]: + raise NotImplementedError + + # + # Generators for the high level operations on the cursor + # + # Like for sync/async connections, these are implemented as generators + # so that different concurrency strategies (threads,asyncio) can use their + # own way of waiting (or better, `connection.wait()`). + # + + def _execute_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `Cursor.execute()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + results = yield from self._maybe_prepare_gen( + pgq, prepare=prepare, binary=binary + ) + if self._conn._pipeline: + yield from self._conn._pipeline._communicate_gen() + else: + assert results is not None + self._check_results(results) + self._results = results + self._select_current_result(0) + + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines available. 
+ """ + pipeline = self._conn._pipeline + assert pipeline + + yield from self._start_query(query) + self._rowcount = 0 + + assert self._execmany_returning is None + self._execmany_returning = returning + + first = True + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + yield from self._maybe_prepare_gen(pgq, prepare=True) + yield from pipeline._communicate_gen() + + self._last_query = query + + if returning: + yield from pipeline._fetch_gen(flush=True) + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _executemany_gen_no_pipeline( + self, query: Query, params_seq: Iterable[Params], returning: bool + ) -> PQGen[None]: + """ + Generator implementing `Cursor.executemany()` with pipelines not available. + """ + yield from self._start_query(query) + first = True + nrows = 0 + for params in params_seq: + if first: + pgq = self._convert_query(query, params) + self._query = pgq + first = False + else: + pgq.dump(params) + + results = yield from self._maybe_prepare_gen(pgq, prepare=True) + assert results is not None + self._check_results(results) + if returning: + self._results.extend(results) + + for res in results: + nrows += res.command_tuples or 0 + + if self._results: + self._select_current_result(0) + + # Override rowcount for the first result. Calls to nextset() will change + # it to the value of that result only, but we hope nobody will notice. + # You haven't read this comment. + self._rowcount = nrows + self._last_query = query + + for cmd in self._conn._prepared.get_maintenance_commands(): + yield from self._conn._exec_command(cmd) + + def _maybe_prepare_gen( + self, + pgq: PostgresQuery, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> PQGen[Optional[List["PGresult"]]]: + # Check if the query is prepared or needs preparing + prep, name = self._get_prepared(pgq, prepare) + if prep is Prepare.NO: + # The query must be executed without preparing + self._execute_send(pgq, binary=binary) + else: + # If the query is not already prepared, prepare it. + if prep is Prepare.SHOULD: + self._send_prepare(name, pgq) + if not self._conn._pipeline: + (result,) = yield from execute(self._pgconn) + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + # Then execute it. + self._send_query_prepared(name, pgq, binary=binary) + + # Update the prepare state of the query. + # If an operation requires to flush our prepared statements cache, + # it will be added to the maintenance commands to execute later. 
+ key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name) + + if self._conn._pipeline: + queued = None + if key is not None: + queued = (key, prep, name) + self._conn._pipeline.result_queue.append((self, queued)) + return None + + # run the query + results = yield from execute(self._pgconn) + + if key is not None: + self._conn._prepared.validate(key, prep, name, results) + + return results + + def _get_prepared( + self, pgq: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + return self._conn._prepared.get(pgq, prepare) + + def _stream_send_gen( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator to send the query for `Cursor.stream()`.""" + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, binary=binary, force_extended=True) + self._pgconn.set_single_row_mode() + self._last_query = query + yield from send(self._pgconn) + + def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]: + res = yield from fetch(self._pgconn) + if res is None: + return None + + status = res.status + if status == SINGLE_TUPLE: + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=first) + if first: + self._make_row = self._make_row_maker() + return res + + elif status == TUPLES_OK or status == COMMAND_OK: + # End of single row results + while res: + res = yield from fetch(self._pgconn) + if status != TUPLES_OK: + raise e.ProgrammingError( + "the operation in stream() didn't produce a result" + ) + return None + + else: + # Errors, unexpected values + return self._raise_for_result(res) + + def _start_query(self, query: Optional[Query] = None) -> PQGen[None]: + """Generator to start the processing of a query. + + It is implemented as generator because it may send additional queries, + such as `begin`. + """ + if self.closed: + raise e.InterfaceError("the cursor is closed") + + self._reset() + if not self._last_query or (self._last_query is not query): + self._last_query = None + self._tx = adapt.Transformer(self) + yield from self._conn._start_query() + + def _start_copy_gen( + self, statement: Query, params: Optional[Params] = None + ) -> PQGen[None]: + """Generator implementing sending a command for `Cursor.copy().""" + + # The connection gets in an unrecoverable state if we attempt COPY in + # pipeline mode. Forbid it explicitly. + if self._conn._pipeline: + raise e.NotSupportedError("COPY cannot be used in pipeline mode") + + yield from self._start_query() + + # Merge the params client-side + if params: + pgq = PostgresClientQuery(self._tx) + pgq.convert(statement, params) + statement = pgq.query + + query = self._convert_query(statement) + + self._execute_send(query, binary=False) + results = yield from execute(self._pgconn) + if len(results) != 1: + raise e.ProgrammingError("COPY cannot be mixed with other operations") + + self._check_copy_result(results[0]) + self._results = results + self._select_current_result(0) + + def _execute_send( + self, + query: PostgresQuery, + *, + force_extended: bool = False, + binary: Optional[bool] = None, + ) -> None: + """ + Implement part of execute() before waiting common to sync and async. + + This is not a generator, but a normal non-blocking function. 
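The Prepare.NO / Prepare.SHOULD decision above is normally driven by the connection's prepare_threshold, but it can be overridden per call with the execute() prepare parameter. A sketch, with a hypothetical `items` table:

    # Force server-side preparation regardless of the threshold
    cur.execute("SELECT * FROM items WHERE id = %s", (1,), prepare=True)

    # Never prepare this statement, even if it was already run many times
    cur.execute("SELECT * FROM items WHERE id = %s", (2,), prepare=False)
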
+ """ + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + self._query = query + + if self._conn._pipeline: + # In pipeline mode always use PQsendQueryParams - see #314 + # Multiple statements in the same query are not allowed anyway. + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_params, + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + ) + elif force_extended or query.params or fmt == BINARY: + self._pgconn.send_query_params( + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, + result_format=fmt, + ) + else: + # If we can, let's use simple query protocol, + # as it can execute more than one statement in a single query. + self._pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresQuery(self._tx) + pgq.convert(query, params) + return pgq + + def _check_results(self, results: List["PGresult"]) -> None: + """ + Verify that the results of a query are valid. + + Verify that the query returned at least one result and that they all + represent a valid result from the database. + """ + if not results: + raise e.InternalError("got no result from the query") + + for res in results: + status = res.status + if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY: + self._raise_for_result(res) + + def _raise_for_result(self, result: "PGresult") -> NoReturn: + """ + Raise an appropriate error message for an unexpected database result + """ + status = result.status + if status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + raise e.ProgrammingError( + "COPY cannot be used with this method; use copy() instead" + ) + else: + raise e.InternalError( + "unexpected result status from query:" f" {pq.ExecStatus(status).name}" + ) + + def _select_current_result( + self, i: int, format: Optional[pq.Format] = None + ) -> None: + """ + Select one of the results in the cursor as the active one. + """ + self._iresult = i + res = self.pgresult = self._results[i] + + # Note: the only reason to override format is to correctly set + # binary loaders on server-side cursors, because send_describe_portal + # only returns a text result. + self._tx.set_pgresult(res, format=format) + + self._pos = 0 + + if res.status == TUPLES_OK: + self._rowcount = self.pgresult.ntuples + + # COPY_OUT has never info about nrows. We need such result for the + # columns in order to return a `description`, but not overwrite the + # cursor rowcount (which was set by the Copy object). + elif res.status != COPY_OUT: + nrows = self.pgresult.command_tuples + self._rowcount = nrows if nrows is not None else -1 + + self._make_row = self._make_row_maker() + + def _set_results_from_pipeline(self, results: List["PGresult"]) -> None: + self._check_results(results) + first_batch = not self._results + + if self._execmany_returning is None: + # Received from execute() + self._results.extend(results) + if first_batch: + self._select_current_result(0) + + else: + # Received from executemany() + if self._execmany_returning: + self._results.extend(results) + if first_batch: + self._select_current_result(0) + self._rowcount = 0 + + # Override rowcount for the first result. 
Calls to nextset() will + # change it to the value of that result only, but we hope nobody + # will notice. + # You haven't read this comment. + if self._rowcount < 0: + self._rowcount = 0 + for res in results: + self._rowcount += res.command_tuples or 0 + + def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_prepare, + name, + query.query, + param_types=query.types, + ) + ) + self._conn._pipeline.result_queue.append(None) + else: + self._pgconn.send_prepare(name, query.query, param_types=query.types) + + def _send_query_prepared( + self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None + ) -> None: + if binary is None: + fmt = self.format + else: + fmt = BINARY if binary else TEXT + + if self._conn._pipeline: + self._conn._pipeline.command_queue.append( + partial( + self._pgconn.send_query_prepared, + name, + pgq.params, + param_formats=pgq.formats, + result_format=fmt, + ) + ) + else: + self._pgconn.send_query_prepared( + name, pgq.params, param_formats=pgq.formats, result_format=fmt + ) + + def _check_result_for_fetch(self) -> None: + if self.closed: + raise e.InterfaceError("the cursor is closed") + res = self.pgresult + if not res: + raise e.ProgrammingError("no result available") + + status = res.status + if status == TUPLES_OK: + return + elif status == FATAL_ERROR: + raise e.error_from_result(res, encoding=self._encoding) + elif status == PIPELINE_ABORTED: + raise e.PipelineAborted("pipeline aborted") + else: + raise e.ProgrammingError("the last operation didn't produce a result") + + def _check_copy_result(self, result: "PGresult") -> None: + """ + Check that the value returned in a copy() operation is a legit COPY. + """ + status = result.status + if status == COPY_IN or status == COPY_OUT: + return + elif status == FATAL_ERROR: + raise e.error_from_result(result, encoding=self._encoding) + else: + raise e.ProgrammingError( + "copy() should be used only with COPY ... TO STDOUT or COPY ..." + f" FROM STDIN statements, got {pq.ExecStatus(status).name}" + ) + + def _scroll(self, value: int, mode: str) -> None: + self._check_result_for_fetch() + assert self.pgresult + if mode == "relative": + newpos = self._pos + value + elif mode == "absolute": + newpos = value + else: + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + if not 0 <= newpos < self.pgresult.ntuples: + raise IndexError("position out of bound") + self._pos = newpos + + def _close(self) -> None: + """Non-blocking part of closing. Common to sync/async.""" + # Don't reset the query because it may be useful to investigate after + # an error. + self._reset(reset_query=False) + self._closed = True + + @property + def _encoding(self) -> str: + return pgconn_encoding(self._pgconn) + + +class Cursor(BaseCursor["Connection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="Cursor[Any]") + + @overload + def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): + ... + + @overload + def __init__( + self: "Cursor[Row]", + connection: "Connection[Any]", + *, + row_factory: RowFactory[Row], + ): + ... 
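The row_factory overloads above select the shape of the fetched rows per cursor; dict_row from psycopg.rows is the usual example, assuming an open connection `conn`:

    from psycopg.rows import dict_row

    with conn.cursor(row_factory=dict_row) as cur:
        cur.execute("SELECT 1 AS x, 'a' AS y")
        print(cur.fetchone())   # {'x': 1, 'y': 'a'}
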
+ + def __init__( + self, + connection: "Connection[Any]", + *, + row_factory: Optional[RowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + def __enter__(self: _Self) -> _Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + self._close() + + @property + def row_factory(self) -> RowFactory[Row]: + """Writable attribute to control how result rows are formed.""" + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: RowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + """ + Execute a query or command to the database. + """ + try: + with self._conn.lock: + self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + """ + Execute the same command with a sequence of input data. + """ + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + with self._conn.lock: + p = self._conn._pipeline + if p: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + with self._conn.pipeline(), self._conn.lock: + self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + with self._conn.lock: + self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> Iterator[Row]: + """ + Iterate row-by-row on a result from the database. + """ + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + with self._conn.lock: + + try: + self._conn.wait(self._stream_send_gen(query, params, binary=binary)) + first = True + while self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. + self._conn.cancel() + try: + while self._conn.wait(self._stream_fetchone_gen(first=False)): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + def fetchone(self) -> Optional[Row]: + """ + Return the next record from the current recordset. + + Return `!None` the recordset is finished. 
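With executemany(returning=True) every statement's result is kept and is then consumed with fetch*() plus nextset(), while stream() yields rows one at a time in single-row mode. A sketch, again assuming a hypothetical `items` table:

    cur.executemany(
        "INSERT INTO items (name) VALUES (%s) RETURNING id",
        [("apple",), ("pear",)],
        returning=True,
    )
    while True:
        print(cur.fetchone())       # one returned id per input row
        if not cur.nextset():
            break

    for row in cur.stream("SELECT generate_series(1, 1000000)"):
        pass                        # rows arrive one by one, never all in memory
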
+ + :rtype: Optional[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + record = self._tx.load_row(self._pos, self._make_row) + if record is not None: + self._pos += 1 + return record + + def fetchmany(self, size: int = 0) -> List[Row]: + """ + Return the next `!size` records from the current recordset. + + `!size` default to `!self.arraysize` if not specified. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + def fetchall(self) -> List[Row]: + """ + Return all the remaining records from the current recordset. + + :rtype: Sequence[Row], with Row defined by `row_factory` + """ + self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + def __iter__(self) -> Iterator[Row]: + self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + def scroll(self, value: int, mode: str = "relative") -> None: + """ + Move the cursor in the result set to a new position according to mode. + + If `!mode` is ``'relative'`` (default), `!value` is taken as offset to + the current position in the result set; if set to ``'absolute'``, + `!value` states an absolute target position. + + Raise `!IndexError` in case a scroll operation would leave the result + set. In this case the position will not change. + """ + self._fetch_pipeline() + self._scroll(value, mode) + + @contextmanager + def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[CopyWriter] = None, + ) -> Iterator[Copy]: + """ + Initiate a :sql:`COPY` operation and return an object to manage it. + + :rtype: Copy + """ + try: + with self._conn.lock: + self._conn.wait(self._start_copy_gen(statement, params)) + + with Copy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + # If a fresher result has been set on the cursor by the Copy object, + # read its properties (especially rowcount). + self._select_current_result(0) + + def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + with self._conn.lock: + self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py new file mode 100644 index 0000000..8971d40 --- /dev/null +++ b/psycopg/psycopg/cursor_async.py @@ -0,0 +1,250 @@ +""" +psycopg async cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from types import TracebackType +from typing import Any, AsyncIterator, Iterable, List +from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload +from contextlib import asynccontextmanager + +from . import pq +from . 
import errors as e +from .abc import Query, Params +from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter +from .rows import Row, RowMaker, AsyncRowFactory +from .cursor import BaseCursor +from ._pipeline import Pipeline + +if TYPE_CHECKING: + from .connection_async import AsyncConnection + +ACTIVE = pq.TransactionStatus.ACTIVE + + +class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncCursor[Any]") + + @overload + def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): + ... + + @overload + def __init__( + self: "AsyncCursor[Row]", + connection: "AsyncConnection[Any]", + *, + row_factory: AsyncRowFactory[Row], + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + ): + super().__init__(connection) + self._row_factory = row_factory or connection.row_factory + + async def __aenter__(self: _Self) -> _Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + async def close(self) -> None: + self._close() + + @property + def row_factory(self) -> AsyncRowFactory[Row]: + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None: + self._row_factory = row_factory + if self.pgresult: + self._make_row = row_factory(self) + + def _make_row_maker(self) -> RowMaker[Row]: + return self._row_factory(self) + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + prepare: Optional[bool] = None, + binary: Optional[bool] = None, + ) -> _Self: + try: + async with self._conn.lock: + await self._conn.wait( + self._execute_gen(query, params, prepare=prepare, binary=binary) + ) + except e.Error as ex: + raise ex.with_traceback(None) + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = False, + ) -> None: + try: + if Pipeline.is_supported(): + # If there is already a pipeline, ride it, in order to avoid + # sending unnecessary Sync. + async with self._conn.lock: + p = self._conn._pipeline + if p: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + # Otherwise, make a new one + if not p: + async with self._conn.pipeline(), self._conn.lock: + await self._conn.wait( + self._executemany_gen_pipeline(query, params_seq, returning) + ) + else: + await self._conn.wait( + self._executemany_gen_no_pipeline(query, params_seq, returning) + ) + except e.Error as ex: + raise ex.with_traceback(None) + + async def stream( + self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + ) -> AsyncIterator[Row]: + if self._pgconn.pipeline_status: + raise e.ProgrammingError("stream() cannot be used in pipeline mode") + + async with self._conn.lock: + + try: + await self._conn.wait( + self._stream_send_gen(query, params, binary=binary) + ) + first = True + while await self._conn.wait(self._stream_fetchone_gen(first)): + # We know that, if we got a result, it has a single row. + rec: Row = self._tx.load_row(0, self._make_row) # type: ignore + yield rec + first = False + + except e.Error as ex: + raise ex.with_traceback(None) + + finally: + if self._pgconn.transaction_status == ACTIVE: + # Try to cancel the query, then consume the results + # already received. 
+ self._conn.cancel() + try: + while await self._conn.wait( + self._stream_fetchone_gen(first=False) + ): + pass + except Exception: + pass + + # Try to get out of ACTIVE state. Just do a single attempt, which + # should work to recover from an error or query cancelled. + try: + await self._conn.wait(self._stream_fetchone_gen(first=False)) + except Exception: + pass + + async def fetchone(self) -> Optional[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + rv = self._tx.load_row(self._pos, self._make_row) + if rv is not None: + self._pos += 1 + return rv + + async def fetchmany(self, size: int = 0) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + + if not size: + size = self.arraysize + records = self._tx.load_rows( + self._pos, + min(self._pos + size, self.pgresult.ntuples), + self._make_row, + ) + self._pos += len(records) + return records + + async def fetchall(self) -> List[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + assert self.pgresult + records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row) + self._pos = self.pgresult.ntuples + return records + + async def __aiter__(self) -> AsyncIterator[Row]: + await self._fetch_pipeline() + self._check_result_for_fetch() + + def load(pos: int) -> Optional[Row]: + return self._tx.load_row(pos, self._make_row) + + while True: + row = load(self._pos) + if row is None: + break + self._pos += 1 + yield row + + async def scroll(self, value: int, mode: str = "relative") -> None: + self._scroll(value, mode) + + @asynccontextmanager + async def copy( + self, + statement: Query, + params: Optional[Params] = None, + *, + writer: Optional[AsyncCopyWriter] = None, + ) -> AsyncIterator[AsyncCopy]: + """ + :rtype: AsyncCopy + """ + try: + async with self._conn.lock: + await self._conn.wait(self._start_copy_gen(statement, params)) + + async with AsyncCopy(self, writer=writer) as copy: + yield copy + except e.Error as ex: + raise ex.with_traceback(None) + + self._select_current_result(0) + + async def _fetch_pipeline(self) -> None: + if ( + self._execmany_returning is not False + and not self.pgresult + and self._conn._pipeline + ): + async with self._conn.lock: + await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True)) diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py new file mode 100644 index 0000000..3c3d8b7 --- /dev/null +++ b/psycopg/psycopg/dbapi20.py @@ -0,0 +1,112 @@ +""" +Compatibility objects with DBAPI 2.0 +""" + +# Copyright (C) 2020 The Psycopg Team + +import time +import datetime as dt +from math import floor +from typing import Any, Sequence, Union + +from . 
import postgres +from .abc import AdaptContext, Buffer +from .types.string import BytesDumper, BytesBinaryDumper + + +class DBAPITypeObject: + def __init__(self, name: str, type_names: Sequence[str]): + self.name = name + self.values = tuple(postgres.types[n].oid for n in type_names) + + def __repr__(self) -> str: + return f"psycopg.{self.name}" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, int): + return other in self.values + else: + return NotImplemented + + def __ne__(self, other: Any) -> bool: + if isinstance(other, int): + return other not in self.values + else: + return NotImplemented + + +BINARY = DBAPITypeObject("BINARY", ("bytea",)) +DATETIME = DBAPITypeObject( + "DATETIME", "timestamp timestamptz date time timetz interval".split() +) +NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split()) +ROWID = DBAPITypeObject("ROWID", ("oid",)) +STRING = DBAPITypeObject("STRING", "text varchar bpchar".split()) + + +class Binary: + def __init__(self, obj: Any): + self.obj = obj + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)" + return f"{self.__class__.__name__}({sobj})" + + +class BinaryBinaryDumper(BytesBinaryDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +class BinaryTextDumper(BytesDumper): + def dump(self, obj: Union[Buffer, Binary]) -> Buffer: + if isinstance(obj, Binary): + return super().dump(obj.obj) + else: + return super().dump(obj) + + +def Date(year: int, month: int, day: int) -> dt.date: + return dt.date(year, month, day) + + +def DateFromTicks(ticks: float) -> dt.date: + return TimestampFromTicks(ticks).date() + + +def Time(hour: int, minute: int, second: int) -> dt.time: + return dt.time(hour, minute, second) + + +def TimeFromTicks(ticks: float) -> dt.time: + return TimestampFromTicks(ticks).time() + + +def Timestamp( + year: int, month: int, day: int, hour: int, minute: int, second: int +) -> dt.datetime: + return dt.datetime(year, month, day, hour, minute, second) + + +def TimestampFromTicks(ticks: float) -> dt.datetime: + secs = floor(ticks) + frac = ticks - secs + t = time.localtime(ticks) + tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff)) + rv = dt.datetime(*t[:6], round(frac * 1_000_000), tzinfo=tzinfo) + return rv + + +def register_dbapi20_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Binary, BinaryTextDumper) + adapters.register_dumper(Binary, BinaryBinaryDumper) + + # Make them also the default dumpers when dumping by bytea oid + adapters.register_dumper(None, BinaryTextDumper) + adapters.register_dumper(None, BinaryBinaryDumper) diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py new file mode 100644 index 0000000..e176954 --- /dev/null +++ b/psycopg/psycopg/errors.py @@ -0,0 +1,1535 @@ +""" +psycopg exceptions + +DBAPI-defined Exceptions are defined in the following hierarchy:: + + Exceptions + |__Warning + |__Error + |__InterfaceError + |__DatabaseError + |__DataError + |__OperationalError + |__IntegrityError + |__InternalError + |__ProgrammingError + |__NotSupportedError +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing_extensions import TypeAlias + +from .pq.abc import PGconn, PGresult +from .pq._enums import DiagnosticField +from ._compat import TypeGuard + +ErrorInfo: 
TypeAlias = Union[None, PGresult, Dict[int, Optional[bytes]]] + +_sqlcodes: Dict[str, "Type[Error]"] = {} + + +class Warning(Exception): + """ + Exception raised for important warnings. + + Defined for DBAPI compatibility, but never raised by ``psycopg``. + """ + + __module__ = "psycopg" + + +class Error(Exception): + """ + Base exception for all the errors psycopg will raise. + + Exception that is the base class of all other error exceptions. You can + use this to catch all errors with one single `!except` statement. + + This exception is guaranteed to be picklable. + """ + + __module__ = "psycopg" + + sqlstate: Optional[str] = None + + def __init__( + self, + *args: Sequence[Any], + info: ErrorInfo = None, + encoding: str = "utf-8", + pgconn: Optional[PGconn] = None + ): + super().__init__(*args) + self._info = info + self._encoding = encoding + self._pgconn = pgconn + + # Handle sqlstate codes for which we don't have a class. + if not self.sqlstate and info: + self.sqlstate = self.diag.sqlstate + + @property + def pgconn(self) -> Optional[PGconn]: + """The connection object, if the error was raised from a connection attempt. + + :rtype: Optional[psycopg.pq.PGconn] + """ + return self._pgconn if self._pgconn else None + + @property + def pgresult(self) -> Optional[PGresult]: + """The result object, if the exception was raised after a failed query. + + :rtype: Optional[psycopg.pq.PGresult] + """ + return self._info if _is_pgresult(self._info) else None + + @property + def diag(self) -> "Diagnostic": + """ + A `Diagnostic` object to inspect details of the errors from the database. + """ + return Diagnostic(self._info, encoding=self._encoding) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + # To make the exception picklable + res[2]["_info"] = _info_to_dict(self._info) + res[2]["_pgconn"] = None + + return res + + +class InterfaceError(Error): + """ + An error related to the database interface rather than the database itself. + """ + + __module__ = "psycopg" + + +class DatabaseError(Error): + """ + Exception raised for errors that are related to the database. + """ + + __module__ = "psycopg" + + def __init_subclass__(cls, code: Optional[str] = None, name: Optional[str] = None): + if code: + _sqlcodes[code] = cls + cls.sqlstate = code + if name: + _sqlcodes[name] = cls + + +class DataError(DatabaseError): + """ + An error caused by problems with the processed data. + + Examples may be division by zero, numeric value out of range, etc. + """ + + __module__ = "psycopg" + + +class OperationalError(DatabaseError): + """ + An error related to the database's operation. + + These errors are not necessarily under the control of the programmer, e.g. + an unexpected disconnect occurs, the data source name is not found, a + transaction could not be processed, a memory allocation error occurred + during processing, etc. + """ + + __module__ = "psycopg" + + +class IntegrityError(DatabaseError): + """ + An error caused when the relational integrity of the database is affected. + + An example may be a foreign key check failed. + """ + + __module__ = "psycopg" + + +class InternalError(DatabaseError): + """ + An error generated when the database encounters an internal error, + + Examples could be the cursor is not valid anymore, the transaction is out + of sync, etc. 
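+
+    In this module it is also the base class of the errors in SQLSTATE
+    classes such as 24 (invalid cursor state) and 25 (invalid transaction
+    state); see `!get_base_exception()` below.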
+ """ + + __module__ = "psycopg" + + +class ProgrammingError(DatabaseError): + """ + Exception raised for programming errors + + Examples may be table not found or already exists, syntax error in the SQL + statement, wrong number of parameters specified, etc. + """ + + __module__ = "psycopg" + + +class NotSupportedError(DatabaseError): + """ + A method or database API was used which is not supported by the database. + """ + + __module__ = "psycopg" + + +class ConnectionTimeout(OperationalError): + """ + Exception raised on timeout of the `~psycopg.Connection.connect()` method. + + The error is raised if the ``connect_timeout`` is specified and a + connection is not obtained in useful time. + + Subclass of `~psycopg.OperationalError`. + """ + + +class PipelineAborted(OperationalError): + """ + Raised when a operation fails because the current pipeline is in aborted state. + + Subclass of `~psycopg.OperationalError`. + """ + + +class Diagnostic: + """Details from a database error report.""" + + def __init__(self, info: ErrorInfo, encoding: str = "utf-8"): + self._info = info + self._encoding = encoding + + @property + def severity(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY) + + @property + def severity_nonlocalized(self) -> Optional[str]: + return self._error_message(DiagnosticField.SEVERITY_NONLOCALIZED) + + @property + def sqlstate(self) -> Optional[str]: + return self._error_message(DiagnosticField.SQLSTATE) + + @property + def message_primary(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_PRIMARY) + + @property + def message_detail(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_DETAIL) + + @property + def message_hint(self) -> Optional[str]: + return self._error_message(DiagnosticField.MESSAGE_HINT) + + @property + def statement_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.STATEMENT_POSITION) + + @property + def internal_position(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_POSITION) + + @property + def internal_query(self) -> Optional[str]: + return self._error_message(DiagnosticField.INTERNAL_QUERY) + + @property + def context(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONTEXT) + + @property + def schema_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.SCHEMA_NAME) + + @property + def table_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.TABLE_NAME) + + @property + def column_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.COLUMN_NAME) + + @property + def datatype_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.DATATYPE_NAME) + + @property + def constraint_name(self) -> Optional[str]: + return self._error_message(DiagnosticField.CONSTRAINT_NAME) + + @property + def source_file(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FILE) + + @property + def source_line(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_LINE) + + @property + def source_function(self) -> Optional[str]: + return self._error_message(DiagnosticField.SOURCE_FUNCTION) + + def _error_message(self, field: DiagnosticField) -> Optional[str]: + if self._info: + if isinstance(self._info, dict): + val = self._info.get(field) + else: + val = self._info.error_field(field) + + if val is not None: + return val.decode(self._encoding, "replace") + + return None + + def __reduce__(self) -> 
Union[str, Tuple[Any, ...]]: + res = super().__reduce__() + if isinstance(res, tuple) and len(res) >= 3: + res[2]["_info"] = _info_to_dict(self._info) + + return res + + +def _info_to_dict(info: ErrorInfo) -> ErrorInfo: + """ + Convert a PGresult to a dictionary to make the info picklable. + """ + # PGresult is a protocol, can't use isinstance + if _is_pgresult(info): + return {v: info.error_field(v) for v in DiagnosticField} + else: + return info + + +def lookup(sqlstate: str) -> Type[Error]: + """Lookup an error code or `constant name`__ and return its exception class. + + Raise `!KeyError` if the code is not found. + + .. __: https://www.postgresql.org/docs/current/errcodes-appendix.html + #ERRCODES-TABLE + """ + return _sqlcodes[sqlstate.upper()] + + +def error_from_result(result: PGresult, encoding: str = "utf-8") -> Error: + from psycopg import pq + + state = result.error_field(DiagnosticField.SQLSTATE) or b"" + cls = _class_for_state(state.decode("ascii")) + return cls( + pq.error_message(result, encoding=encoding), + info=result, + encoding=encoding, + ) + + +def _is_pgresult(info: ErrorInfo) -> TypeGuard[PGresult]: + """Return True if an ErrorInfo is a PGresult instance.""" + # PGresult is a protocol, can't use isinstance + return hasattr(info, "error_field") + + +def _class_for_state(sqlstate: str) -> Type[Error]: + try: + return lookup(sqlstate) + except KeyError: + return get_base_exception(sqlstate) + + +def get_base_exception(sqlstate: str) -> Type[Error]: + return ( + _base_exc_map.get(sqlstate[:2]) + or _base_exc_map.get(sqlstate[:1]) + or DatabaseError + ) + + +_base_exc_map = { + "08": OperationalError, # Connection Exception + "0A": NotSupportedError, # Feature Not Supported + "20": ProgrammingError, # Case Not Foud + "21": ProgrammingError, # Cardinality Violation + "22": DataError, # Data Exception + "23": IntegrityError, # Integrity Constraint Violation + "24": InternalError, # Invalid Cursor State + "25": InternalError, # Invalid Transaction State + "26": ProgrammingError, # Invalid SQL Statement Name * + "27": OperationalError, # Triggered Data Change Violation + "28": OperationalError, # Invalid Authorization Specification + "2B": InternalError, # Dependent Privilege Descriptors Still Exist + "2D": InternalError, # Invalid Transaction Termination + "2F": OperationalError, # SQL Routine Exception * + "34": ProgrammingError, # Invalid Cursor Name * + "38": OperationalError, # External Routine Exception * + "39": OperationalError, # External Routine Invocation Exception * + "3B": OperationalError, # Savepoint Exception * + "3D": ProgrammingError, # Invalid Catalog Name + "3F": ProgrammingError, # Invalid Schema Name + "40": OperationalError, # Transaction Rollback + "42": ProgrammingError, # Syntax Error or Access Rule Violation + "44": ProgrammingError, # WITH CHECK OPTION Violation + "53": OperationalError, # Insufficient Resources + "54": OperationalError, # Program Limit Exceeded + "55": OperationalError, # Object Not In Prerequisite State + "57": OperationalError, # Operator Intervention + "58": OperationalError, # System Error (errors external to PostgreSQL itself) + "F": OperationalError, # Configuration File Error + "H": OperationalError, # Foreign Data Wrapper Error (SQL/MED) + "P": ProgrammingError, # PL/pgSQL Error + "X": InternalError, # Internal Error +} + + +# Error classes generated by tools/update_errors.py + +# fmt: off +# autogenerated: start + + +# Class 02 - No Data (this is also a warning class per the SQL standard) + +class NoData(DatabaseError, 
+ code='02000', name='NO_DATA'): + pass + +class NoAdditionalDynamicResultSetsReturned(DatabaseError, + code='02001', name='NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED'): + pass + + +# Class 03 - SQL Statement Not Yet Complete + +class SqlStatementNotYetComplete(DatabaseError, + code='03000', name='SQL_STATEMENT_NOT_YET_COMPLETE'): + pass + + +# Class 08 - Connection Exception + +class ConnectionException(OperationalError, + code='08000', name='CONNECTION_EXCEPTION'): + pass + +class SqlclientUnableToEstablishSqlconnection(OperationalError, + code='08001', name='SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION'): + pass + +class ConnectionDoesNotExist(OperationalError, + code='08003', name='CONNECTION_DOES_NOT_EXIST'): + pass + +class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError, + code='08004', name='SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION'): + pass + +class ConnectionFailure(OperationalError, + code='08006', name='CONNECTION_FAILURE'): + pass + +class TransactionResolutionUnknown(OperationalError, + code='08007', name='TRANSACTION_RESOLUTION_UNKNOWN'): + pass + +class ProtocolViolation(OperationalError, + code='08P01', name='PROTOCOL_VIOLATION'): + pass + + +# Class 09 - Triggered Action Exception + +class TriggeredActionException(DatabaseError, + code='09000', name='TRIGGERED_ACTION_EXCEPTION'): + pass + + +# Class 0A - Feature Not Supported + +class FeatureNotSupported(NotSupportedError, + code='0A000', name='FEATURE_NOT_SUPPORTED'): + pass + + +# Class 0B - Invalid Transaction Initiation + +class InvalidTransactionInitiation(DatabaseError, + code='0B000', name='INVALID_TRANSACTION_INITIATION'): + pass + + +# Class 0F - Locator Exception + +class LocatorException(DatabaseError, + code='0F000', name='LOCATOR_EXCEPTION'): + pass + +class InvalidLocatorSpecification(DatabaseError, + code='0F001', name='INVALID_LOCATOR_SPECIFICATION'): + pass + + +# Class 0L - Invalid Grantor + +class InvalidGrantor(DatabaseError, + code='0L000', name='INVALID_GRANTOR'): + pass + +class InvalidGrantOperation(DatabaseError, + code='0LP01', name='INVALID_GRANT_OPERATION'): + pass + + +# Class 0P - Invalid Role Specification + +class InvalidRoleSpecification(DatabaseError, + code='0P000', name='INVALID_ROLE_SPECIFICATION'): + pass + + +# Class 0Z - Diagnostics Exception + +class DiagnosticsException(DatabaseError, + code='0Z000', name='DIAGNOSTICS_EXCEPTION'): + pass + +class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError, + code='0Z002', name='STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER'): + pass + + +# Class 20 - Case Not Found + +class CaseNotFound(ProgrammingError, + code='20000', name='CASE_NOT_FOUND'): + pass + + +# Class 21 - Cardinality Violation + +class CardinalityViolation(ProgrammingError, + code='21000', name='CARDINALITY_VIOLATION'): + pass + + +# Class 22 - Data Exception + +class DataException(DataError, + code='22000', name='DATA_EXCEPTION'): + pass + +class StringDataRightTruncation(DataError, + code='22001', name='STRING_DATA_RIGHT_TRUNCATION'): + pass + +class NullValueNoIndicatorParameter(DataError, + code='22002', name='NULL_VALUE_NO_INDICATOR_PARAMETER'): + pass + +class NumericValueOutOfRange(DataError, + code='22003', name='NUMERIC_VALUE_OUT_OF_RANGE'): + pass + +class NullValueNotAllowed(DataError, + code='22004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class ErrorInAssignment(DataError, + code='22005', name='ERROR_IN_ASSIGNMENT'): + pass + +class InvalidDatetimeFormat(DataError, + code='22007', name='INVALID_DATETIME_FORMAT'): + pass + 
+class DatetimeFieldOverflow(DataError, + code='22008', name='DATETIME_FIELD_OVERFLOW'): + pass + +class InvalidTimeZoneDisplacementValue(DataError, + code='22009', name='INVALID_TIME_ZONE_DISPLACEMENT_VALUE'): + pass + +class EscapeCharacterConflict(DataError, + code='2200B', name='ESCAPE_CHARACTER_CONFLICT'): + pass + +class InvalidUseOfEscapeCharacter(DataError, + code='2200C', name='INVALID_USE_OF_ESCAPE_CHARACTER'): + pass + +class InvalidEscapeOctet(DataError, + code='2200D', name='INVALID_ESCAPE_OCTET'): + pass + +class ZeroLengthCharacterString(DataError, + code='2200F', name='ZERO_LENGTH_CHARACTER_STRING'): + pass + +class MostSpecificTypeMismatch(DataError, + code='2200G', name='MOST_SPECIFIC_TYPE_MISMATCH'): + pass + +class SequenceGeneratorLimitExceeded(DataError, + code='2200H', name='SEQUENCE_GENERATOR_LIMIT_EXCEEDED'): + pass + +class NotAnXmlDocument(DataError, + code='2200L', name='NOT_AN_XML_DOCUMENT'): + pass + +class InvalidXmlDocument(DataError, + code='2200M', name='INVALID_XML_DOCUMENT'): + pass + +class InvalidXmlContent(DataError, + code='2200N', name='INVALID_XML_CONTENT'): + pass + +class InvalidXmlComment(DataError, + code='2200S', name='INVALID_XML_COMMENT'): + pass + +class InvalidXmlProcessingInstruction(DataError, + code='2200T', name='INVALID_XML_PROCESSING_INSTRUCTION'): + pass + +class InvalidIndicatorParameterValue(DataError, + code='22010', name='INVALID_INDICATOR_PARAMETER_VALUE'): + pass + +class SubstringError(DataError, + code='22011', name='SUBSTRING_ERROR'): + pass + +class DivisionByZero(DataError, + code='22012', name='DIVISION_BY_ZERO'): + pass + +class InvalidPrecedingOrFollowingSize(DataError, + code='22013', name='INVALID_PRECEDING_OR_FOLLOWING_SIZE'): + pass + +class InvalidArgumentForNtileFunction(DataError, + code='22014', name='INVALID_ARGUMENT_FOR_NTILE_FUNCTION'): + pass + +class IntervalFieldOverflow(DataError, + code='22015', name='INTERVAL_FIELD_OVERFLOW'): + pass + +class InvalidArgumentForNthValueFunction(DataError, + code='22016', name='INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION'): + pass + +class InvalidCharacterValueForCast(DataError, + code='22018', name='INVALID_CHARACTER_VALUE_FOR_CAST'): + pass + +class InvalidEscapeCharacter(DataError, + code='22019', name='INVALID_ESCAPE_CHARACTER'): + pass + +class InvalidRegularExpression(DataError, + code='2201B', name='INVALID_REGULAR_EXPRESSION'): + pass + +class InvalidArgumentForLogarithm(DataError, + code='2201E', name='INVALID_ARGUMENT_FOR_LOGARITHM'): + pass + +class InvalidArgumentForPowerFunction(DataError, + code='2201F', name='INVALID_ARGUMENT_FOR_POWER_FUNCTION'): + pass + +class InvalidArgumentForWidthBucketFunction(DataError, + code='2201G', name='INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION'): + pass + +class InvalidRowCountInLimitClause(DataError, + code='2201W', name='INVALID_ROW_COUNT_IN_LIMIT_CLAUSE'): + pass + +class InvalidRowCountInResultOffsetClause(DataError, + code='2201X', name='INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE'): + pass + +class CharacterNotInRepertoire(DataError, + code='22021', name='CHARACTER_NOT_IN_REPERTOIRE'): + pass + +class IndicatorOverflow(DataError, + code='22022', name='INDICATOR_OVERFLOW'): + pass + +class InvalidParameterValue(DataError, + code='22023', name='INVALID_PARAMETER_VALUE'): + pass + +class UnterminatedCString(DataError, + code='22024', name='UNTERMINATED_C_STRING'): + pass + +class InvalidEscapeSequence(DataError, + code='22025', name='INVALID_ESCAPE_SEQUENCE'): + pass + +class StringDataLengthMismatch(DataError, + code='22026', 
name='STRING_DATA_LENGTH_MISMATCH'): + pass + +class TrimError(DataError, + code='22027', name='TRIM_ERROR'): + pass + +class ArraySubscriptError(DataError, + code='2202E', name='ARRAY_SUBSCRIPT_ERROR'): + pass + +class InvalidTablesampleRepeat(DataError, + code='2202G', name='INVALID_TABLESAMPLE_REPEAT'): + pass + +class InvalidTablesampleArgument(DataError, + code='2202H', name='INVALID_TABLESAMPLE_ARGUMENT'): + pass + +class DuplicateJsonObjectKeyValue(DataError, + code='22030', name='DUPLICATE_JSON_OBJECT_KEY_VALUE'): + pass + +class InvalidArgumentForSqlJsonDatetimeFunction(DataError, + code='22031', name='INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION'): + pass + +class InvalidJsonText(DataError, + code='22032', name='INVALID_JSON_TEXT'): + pass + +class InvalidSqlJsonSubscript(DataError, + code='22033', name='INVALID_SQL_JSON_SUBSCRIPT'): + pass + +class MoreThanOneSqlJsonItem(DataError, + code='22034', name='MORE_THAN_ONE_SQL_JSON_ITEM'): + pass + +class NoSqlJsonItem(DataError, + code='22035', name='NO_SQL_JSON_ITEM'): + pass + +class NonNumericSqlJsonItem(DataError, + code='22036', name='NON_NUMERIC_SQL_JSON_ITEM'): + pass + +class NonUniqueKeysInAJsonObject(DataError, + code='22037', name='NON_UNIQUE_KEYS_IN_A_JSON_OBJECT'): + pass + +class SingletonSqlJsonItemRequired(DataError, + code='22038', name='SINGLETON_SQL_JSON_ITEM_REQUIRED'): + pass + +class SqlJsonArrayNotFound(DataError, + code='22039', name='SQL_JSON_ARRAY_NOT_FOUND'): + pass + +class SqlJsonMemberNotFound(DataError, + code='2203A', name='SQL_JSON_MEMBER_NOT_FOUND'): + pass + +class SqlJsonNumberNotFound(DataError, + code='2203B', name='SQL_JSON_NUMBER_NOT_FOUND'): + pass + +class SqlJsonObjectNotFound(DataError, + code='2203C', name='SQL_JSON_OBJECT_NOT_FOUND'): + pass + +class TooManyJsonArrayElements(DataError, + code='2203D', name='TOO_MANY_JSON_ARRAY_ELEMENTS'): + pass + +class TooManyJsonObjectMembers(DataError, + code='2203E', name='TOO_MANY_JSON_OBJECT_MEMBERS'): + pass + +class SqlJsonScalarRequired(DataError, + code='2203F', name='SQL_JSON_SCALAR_REQUIRED'): + pass + +class SqlJsonItemCannotBeCastToTargetType(DataError, + code='2203G', name='SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE'): + pass + +class FloatingPointException(DataError, + code='22P01', name='FLOATING_POINT_EXCEPTION'): + pass + +class InvalidTextRepresentation(DataError, + code='22P02', name='INVALID_TEXT_REPRESENTATION'): + pass + +class InvalidBinaryRepresentation(DataError, + code='22P03', name='INVALID_BINARY_REPRESENTATION'): + pass + +class BadCopyFileFormat(DataError, + code='22P04', name='BAD_COPY_FILE_FORMAT'): + pass + +class UntranslatableCharacter(DataError, + code='22P05', name='UNTRANSLATABLE_CHARACTER'): + pass + +class NonstandardUseOfEscapeCharacter(DataError, + code='22P06', name='NONSTANDARD_USE_OF_ESCAPE_CHARACTER'): + pass + + +# Class 23 - Integrity Constraint Violation + +class IntegrityConstraintViolation(IntegrityError, + code='23000', name='INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class RestrictViolation(IntegrityError, + code='23001', name='RESTRICT_VIOLATION'): + pass + +class NotNullViolation(IntegrityError, + code='23502', name='NOT_NULL_VIOLATION'): + pass + +class ForeignKeyViolation(IntegrityError, + code='23503', name='FOREIGN_KEY_VIOLATION'): + pass + +class UniqueViolation(IntegrityError, + code='23505', name='UNIQUE_VIOLATION'): + pass + +class CheckViolation(IntegrityError, + code='23514', name='CHECK_VIOLATION'): + pass + +class ExclusionViolation(IntegrityError, + code='23P01', 
name='EXCLUSION_VIOLATION'): + pass + + +# Class 24 - Invalid Cursor State + +class InvalidCursorState(InternalError, + code='24000', name='INVALID_CURSOR_STATE'): + pass + + +# Class 25 - Invalid Transaction State + +class InvalidTransactionState(InternalError, + code='25000', name='INVALID_TRANSACTION_STATE'): + pass + +class ActiveSqlTransaction(InternalError, + code='25001', name='ACTIVE_SQL_TRANSACTION'): + pass + +class BranchTransactionAlreadyActive(InternalError, + code='25002', name='BRANCH_TRANSACTION_ALREADY_ACTIVE'): + pass + +class InappropriateAccessModeForBranchTransaction(InternalError, + code='25003', name='INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION'): + pass + +class InappropriateIsolationLevelForBranchTransaction(InternalError, + code='25004', name='INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION'): + pass + +class NoActiveSqlTransactionForBranchTransaction(InternalError, + code='25005', name='NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION'): + pass + +class ReadOnlySqlTransaction(InternalError, + code='25006', name='READ_ONLY_SQL_TRANSACTION'): + pass + +class SchemaAndDataStatementMixingNotSupported(InternalError, + code='25007', name='SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED'): + pass + +class HeldCursorRequiresSameIsolationLevel(InternalError, + code='25008', name='HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL'): + pass + +class NoActiveSqlTransaction(InternalError, + code='25P01', name='NO_ACTIVE_SQL_TRANSACTION'): + pass + +class InFailedSqlTransaction(InternalError, + code='25P02', name='IN_FAILED_SQL_TRANSACTION'): + pass + +class IdleInTransactionSessionTimeout(InternalError, + code='25P03', name='IDLE_IN_TRANSACTION_SESSION_TIMEOUT'): + pass + + +# Class 26 - Invalid SQL Statement Name + +class InvalidSqlStatementName(ProgrammingError, + code='26000', name='INVALID_SQL_STATEMENT_NAME'): + pass + + +# Class 27 - Triggered Data Change Violation + +class TriggeredDataChangeViolation(OperationalError, + code='27000', name='TRIGGERED_DATA_CHANGE_VIOLATION'): + pass + + +# Class 28 - Invalid Authorization Specification + +class InvalidAuthorizationSpecification(OperationalError, + code='28000', name='INVALID_AUTHORIZATION_SPECIFICATION'): + pass + +class InvalidPassword(OperationalError, + code='28P01', name='INVALID_PASSWORD'): + pass + + +# Class 2B - Dependent Privilege Descriptors Still Exist + +class DependentPrivilegeDescriptorsStillExist(InternalError, + code='2B000', name='DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST'): + pass + +class DependentObjectsStillExist(InternalError, + code='2BP01', name='DEPENDENT_OBJECTS_STILL_EXIST'): + pass + + +# Class 2D - Invalid Transaction Termination + +class InvalidTransactionTermination(InternalError, + code='2D000', name='INVALID_TRANSACTION_TERMINATION'): + pass + + +# Class 2F - SQL Routine Exception + +class SqlRoutineException(OperationalError, + code='2F000', name='SQL_ROUTINE_EXCEPTION'): + pass + +class ModifyingSqlDataNotPermitted(OperationalError, + code='2F002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttempted(OperationalError, + code='2F003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermitted(OperationalError, + code='2F004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + +class FunctionExecutedNoReturnStatement(OperationalError, + code='2F005', name='FUNCTION_EXECUTED_NO_RETURN_STATEMENT'): + pass + + +# Class 34 - Invalid Cursor Name + +class InvalidCursorName(ProgrammingError, + code='34000', 
name='INVALID_CURSOR_NAME'): + pass + + +# Class 38 - External Routine Exception + +class ExternalRoutineException(OperationalError, + code='38000', name='EXTERNAL_ROUTINE_EXCEPTION'): + pass + +class ContainingSqlNotPermitted(OperationalError, + code='38001', name='CONTAINING_SQL_NOT_PERMITTED'): + pass + +class ModifyingSqlDataNotPermittedExt(OperationalError, + code='38002', name='MODIFYING_SQL_DATA_NOT_PERMITTED'): + pass + +class ProhibitedSqlStatementAttemptedExt(OperationalError, + code='38003', name='PROHIBITED_SQL_STATEMENT_ATTEMPTED'): + pass + +class ReadingSqlDataNotPermittedExt(OperationalError, + code='38004', name='READING_SQL_DATA_NOT_PERMITTED'): + pass + + +# Class 39 - External Routine Invocation Exception + +class ExternalRoutineInvocationException(OperationalError, + code='39000', name='EXTERNAL_ROUTINE_INVOCATION_EXCEPTION'): + pass + +class InvalidSqlstateReturned(OperationalError, + code='39001', name='INVALID_SQLSTATE_RETURNED'): + pass + +class NullValueNotAllowedExt(OperationalError, + code='39004', name='NULL_VALUE_NOT_ALLOWED'): + pass + +class TriggerProtocolViolated(OperationalError, + code='39P01', name='TRIGGER_PROTOCOL_VIOLATED'): + pass + +class SrfProtocolViolated(OperationalError, + code='39P02', name='SRF_PROTOCOL_VIOLATED'): + pass + +class EventTriggerProtocolViolated(OperationalError, + code='39P03', name='EVENT_TRIGGER_PROTOCOL_VIOLATED'): + pass + + +# Class 3B - Savepoint Exception + +class SavepointException(OperationalError, + code='3B000', name='SAVEPOINT_EXCEPTION'): + pass + +class InvalidSavepointSpecification(OperationalError, + code='3B001', name='INVALID_SAVEPOINT_SPECIFICATION'): + pass + + +# Class 3D - Invalid Catalog Name + +class InvalidCatalogName(ProgrammingError, + code='3D000', name='INVALID_CATALOG_NAME'): + pass + + +# Class 3F - Invalid Schema Name + +class InvalidSchemaName(ProgrammingError, + code='3F000', name='INVALID_SCHEMA_NAME'): + pass + + +# Class 40 - Transaction Rollback + +class TransactionRollback(OperationalError, + code='40000', name='TRANSACTION_ROLLBACK'): + pass + +class SerializationFailure(OperationalError, + code='40001', name='SERIALIZATION_FAILURE'): + pass + +class TransactionIntegrityConstraintViolation(OperationalError, + code='40002', name='TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION'): + pass + +class StatementCompletionUnknown(OperationalError, + code='40003', name='STATEMENT_COMPLETION_UNKNOWN'): + pass + +class DeadlockDetected(OperationalError, + code='40P01', name='DEADLOCK_DETECTED'): + pass + + +# Class 42 - Syntax Error or Access Rule Violation + +class SyntaxErrorOrAccessRuleViolation(ProgrammingError, + code='42000', name='SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION'): + pass + +class InsufficientPrivilege(ProgrammingError, + code='42501', name='INSUFFICIENT_PRIVILEGE'): + pass + +class SyntaxError(ProgrammingError, + code='42601', name='SYNTAX_ERROR'): + pass + +class InvalidName(ProgrammingError, + code='42602', name='INVALID_NAME'): + pass + +class InvalidColumnDefinition(ProgrammingError, + code='42611', name='INVALID_COLUMN_DEFINITION'): + pass + +class NameTooLong(ProgrammingError, + code='42622', name='NAME_TOO_LONG'): + pass + +class DuplicateColumn(ProgrammingError, + code='42701', name='DUPLICATE_COLUMN'): + pass + +class AmbiguousColumn(ProgrammingError, + code='42702', name='AMBIGUOUS_COLUMN'): + pass + +class UndefinedColumn(ProgrammingError, + code='42703', name='UNDEFINED_COLUMN'): + pass + +class UndefinedObject(ProgrammingError, + code='42704', name='UNDEFINED_OBJECT'): + pass + 
+class DuplicateObject(ProgrammingError, + code='42710', name='DUPLICATE_OBJECT'): + pass + +class DuplicateAlias(ProgrammingError, + code='42712', name='DUPLICATE_ALIAS'): + pass + +class DuplicateFunction(ProgrammingError, + code='42723', name='DUPLICATE_FUNCTION'): + pass + +class AmbiguousFunction(ProgrammingError, + code='42725', name='AMBIGUOUS_FUNCTION'): + pass + +class GroupingError(ProgrammingError, + code='42803', name='GROUPING_ERROR'): + pass + +class DatatypeMismatch(ProgrammingError, + code='42804', name='DATATYPE_MISMATCH'): + pass + +class WrongObjectType(ProgrammingError, + code='42809', name='WRONG_OBJECT_TYPE'): + pass + +class InvalidForeignKey(ProgrammingError, + code='42830', name='INVALID_FOREIGN_KEY'): + pass + +class CannotCoerce(ProgrammingError, + code='42846', name='CANNOT_COERCE'): + pass + +class UndefinedFunction(ProgrammingError, + code='42883', name='UNDEFINED_FUNCTION'): + pass + +class GeneratedAlways(ProgrammingError, + code='428C9', name='GENERATED_ALWAYS'): + pass + +class ReservedName(ProgrammingError, + code='42939', name='RESERVED_NAME'): + pass + +class UndefinedTable(ProgrammingError, + code='42P01', name='UNDEFINED_TABLE'): + pass + +class UndefinedParameter(ProgrammingError, + code='42P02', name='UNDEFINED_PARAMETER'): + pass + +class DuplicateCursor(ProgrammingError, + code='42P03', name='DUPLICATE_CURSOR'): + pass + +class DuplicateDatabase(ProgrammingError, + code='42P04', name='DUPLICATE_DATABASE'): + pass + +class DuplicatePreparedStatement(ProgrammingError, + code='42P05', name='DUPLICATE_PREPARED_STATEMENT'): + pass + +class DuplicateSchema(ProgrammingError, + code='42P06', name='DUPLICATE_SCHEMA'): + pass + +class DuplicateTable(ProgrammingError, + code='42P07', name='DUPLICATE_TABLE'): + pass + +class AmbiguousParameter(ProgrammingError, + code='42P08', name='AMBIGUOUS_PARAMETER'): + pass + +class AmbiguousAlias(ProgrammingError, + code='42P09', name='AMBIGUOUS_ALIAS'): + pass + +class InvalidColumnReference(ProgrammingError, + code='42P10', name='INVALID_COLUMN_REFERENCE'): + pass + +class InvalidCursorDefinition(ProgrammingError, + code='42P11', name='INVALID_CURSOR_DEFINITION'): + pass + +class InvalidDatabaseDefinition(ProgrammingError, + code='42P12', name='INVALID_DATABASE_DEFINITION'): + pass + +class InvalidFunctionDefinition(ProgrammingError, + code='42P13', name='INVALID_FUNCTION_DEFINITION'): + pass + +class InvalidPreparedStatementDefinition(ProgrammingError, + code='42P14', name='INVALID_PREPARED_STATEMENT_DEFINITION'): + pass + +class InvalidSchemaDefinition(ProgrammingError, + code='42P15', name='INVALID_SCHEMA_DEFINITION'): + pass + +class InvalidTableDefinition(ProgrammingError, + code='42P16', name='INVALID_TABLE_DEFINITION'): + pass + +class InvalidObjectDefinition(ProgrammingError, + code='42P17', name='INVALID_OBJECT_DEFINITION'): + pass + +class IndeterminateDatatype(ProgrammingError, + code='42P18', name='INDETERMINATE_DATATYPE'): + pass + +class InvalidRecursion(ProgrammingError, + code='42P19', name='INVALID_RECURSION'): + pass + +class WindowingError(ProgrammingError, + code='42P20', name='WINDOWING_ERROR'): + pass + +class CollationMismatch(ProgrammingError, + code='42P21', name='COLLATION_MISMATCH'): + pass + +class IndeterminateCollation(ProgrammingError, + code='42P22', name='INDETERMINATE_COLLATION'): + pass + + +# Class 44 - WITH CHECK OPTION Violation + +class WithCheckOptionViolation(ProgrammingError, + code='44000', name='WITH_CHECK_OPTION_VIOLATION'): + pass + + +# Class 53 - Insufficient Resources + 
+class InsufficientResources(OperationalError, + code='53000', name='INSUFFICIENT_RESOURCES'): + pass + +class DiskFull(OperationalError, + code='53100', name='DISK_FULL'): + pass + +class OutOfMemory(OperationalError, + code='53200', name='OUT_OF_MEMORY'): + pass + +class TooManyConnections(OperationalError, + code='53300', name='TOO_MANY_CONNECTIONS'): + pass + +class ConfigurationLimitExceeded(OperationalError, + code='53400', name='CONFIGURATION_LIMIT_EXCEEDED'): + pass + + +# Class 54 - Program Limit Exceeded + +class ProgramLimitExceeded(OperationalError, + code='54000', name='PROGRAM_LIMIT_EXCEEDED'): + pass + +class StatementTooComplex(OperationalError, + code='54001', name='STATEMENT_TOO_COMPLEX'): + pass + +class TooManyColumns(OperationalError, + code='54011', name='TOO_MANY_COLUMNS'): + pass + +class TooManyArguments(OperationalError, + code='54023', name='TOO_MANY_ARGUMENTS'): + pass + + +# Class 55 - Object Not In Prerequisite State + +class ObjectNotInPrerequisiteState(OperationalError, + code='55000', name='OBJECT_NOT_IN_PREREQUISITE_STATE'): + pass + +class ObjectInUse(OperationalError, + code='55006', name='OBJECT_IN_USE'): + pass + +class CantChangeRuntimeParam(OperationalError, + code='55P02', name='CANT_CHANGE_RUNTIME_PARAM'): + pass + +class LockNotAvailable(OperationalError, + code='55P03', name='LOCK_NOT_AVAILABLE'): + pass + +class UnsafeNewEnumValueUsage(OperationalError, + code='55P04', name='UNSAFE_NEW_ENUM_VALUE_USAGE'): + pass + + +# Class 57 - Operator Intervention + +class OperatorIntervention(OperationalError, + code='57000', name='OPERATOR_INTERVENTION'): + pass + +class QueryCanceled(OperationalError, + code='57014', name='QUERY_CANCELED'): + pass + +class AdminShutdown(OperationalError, + code='57P01', name='ADMIN_SHUTDOWN'): + pass + +class CrashShutdown(OperationalError, + code='57P02', name='CRASH_SHUTDOWN'): + pass + +class CannotConnectNow(OperationalError, + code='57P03', name='CANNOT_CONNECT_NOW'): + pass + +class DatabaseDropped(OperationalError, + code='57P04', name='DATABASE_DROPPED'): + pass + +class IdleSessionTimeout(OperationalError, + code='57P05', name='IDLE_SESSION_TIMEOUT'): + pass + + +# Class 58 - System Error (errors external to PostgreSQL itself) + +class SystemError(OperationalError, + code='58000', name='SYSTEM_ERROR'): + pass + +class IoError(OperationalError, + code='58030', name='IO_ERROR'): + pass + +class UndefinedFile(OperationalError, + code='58P01', name='UNDEFINED_FILE'): + pass + +class DuplicateFile(OperationalError, + code='58P02', name='DUPLICATE_FILE'): + pass + + +# Class 72 - Snapshot Failure + +class SnapshotTooOld(DatabaseError, + code='72000', name='SNAPSHOT_TOO_OLD'): + pass + + +# Class F0 - Configuration File Error + +class ConfigFileError(OperationalError, + code='F0000', name='CONFIG_FILE_ERROR'): + pass + +class LockFileExists(OperationalError, + code='F0001', name='LOCK_FILE_EXISTS'): + pass + + +# Class HV - Foreign Data Wrapper Error (SQL/MED) + +class FdwError(OperationalError, + code='HV000', name='FDW_ERROR'): + pass + +class FdwOutOfMemory(OperationalError, + code='HV001', name='FDW_OUT_OF_MEMORY'): + pass + +class FdwDynamicParameterValueNeeded(OperationalError, + code='HV002', name='FDW_DYNAMIC_PARAMETER_VALUE_NEEDED'): + pass + +class FdwInvalidDataType(OperationalError, + code='HV004', name='FDW_INVALID_DATA_TYPE'): + pass + +class FdwColumnNameNotFound(OperationalError, + code='HV005', name='FDW_COLUMN_NAME_NOT_FOUND'): + pass + +class FdwInvalidDataTypeDescriptors(OperationalError, + 
code='HV006', name='FDW_INVALID_DATA_TYPE_DESCRIPTORS'): + pass + +class FdwInvalidColumnName(OperationalError, + code='HV007', name='FDW_INVALID_COLUMN_NAME'): + pass + +class FdwInvalidColumnNumber(OperationalError, + code='HV008', name='FDW_INVALID_COLUMN_NUMBER'): + pass + +class FdwInvalidUseOfNullPointer(OperationalError, + code='HV009', name='FDW_INVALID_USE_OF_NULL_POINTER'): + pass + +class FdwInvalidStringFormat(OperationalError, + code='HV00A', name='FDW_INVALID_STRING_FORMAT'): + pass + +class FdwInvalidHandle(OperationalError, + code='HV00B', name='FDW_INVALID_HANDLE'): + pass + +class FdwInvalidOptionIndex(OperationalError, + code='HV00C', name='FDW_INVALID_OPTION_INDEX'): + pass + +class FdwInvalidOptionName(OperationalError, + code='HV00D', name='FDW_INVALID_OPTION_NAME'): + pass + +class FdwOptionNameNotFound(OperationalError, + code='HV00J', name='FDW_OPTION_NAME_NOT_FOUND'): + pass + +class FdwReplyHandle(OperationalError, + code='HV00K', name='FDW_REPLY_HANDLE'): + pass + +class FdwUnableToCreateExecution(OperationalError, + code='HV00L', name='FDW_UNABLE_TO_CREATE_EXECUTION'): + pass + +class FdwUnableToCreateReply(OperationalError, + code='HV00M', name='FDW_UNABLE_TO_CREATE_REPLY'): + pass + +class FdwUnableToEstablishConnection(OperationalError, + code='HV00N', name='FDW_UNABLE_TO_ESTABLISH_CONNECTION'): + pass + +class FdwNoSchemas(OperationalError, + code='HV00P', name='FDW_NO_SCHEMAS'): + pass + +class FdwSchemaNotFound(OperationalError, + code='HV00Q', name='FDW_SCHEMA_NOT_FOUND'): + pass + +class FdwTableNotFound(OperationalError, + code='HV00R', name='FDW_TABLE_NOT_FOUND'): + pass + +class FdwFunctionSequenceError(OperationalError, + code='HV010', name='FDW_FUNCTION_SEQUENCE_ERROR'): + pass + +class FdwTooManyHandles(OperationalError, + code='HV014', name='FDW_TOO_MANY_HANDLES'): + pass + +class FdwInconsistentDescriptorInformation(OperationalError, + code='HV021', name='FDW_INCONSISTENT_DESCRIPTOR_INFORMATION'): + pass + +class FdwInvalidAttributeValue(OperationalError, + code='HV024', name='FDW_INVALID_ATTRIBUTE_VALUE'): + pass + +class FdwInvalidStringLengthOrBufferLength(OperationalError, + code='HV090', name='FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH'): + pass + +class FdwInvalidDescriptorFieldIdentifier(OperationalError, + code='HV091', name='FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER'): + pass + + +# Class P0 - PL/pgSQL Error + +class PlpgsqlError(ProgrammingError, + code='P0000', name='PLPGSQL_ERROR'): + pass + +class RaiseException(ProgrammingError, + code='P0001', name='RAISE_EXCEPTION'): + pass + +class NoDataFound(ProgrammingError, + code='P0002', name='NO_DATA_FOUND'): + pass + +class TooManyRows(ProgrammingError, + code='P0003', name='TOO_MANY_ROWS'): + pass + +class AssertFailure(ProgrammingError, + code='P0004', name='ASSERT_FAILURE'): + pass + + +# Class XX - Internal Error + +class InternalError_(InternalError, + code='XX000', name='INTERNAL_ERROR'): + pass + +class DataCorrupted(InternalError, + code='XX001', name='DATA_CORRUPTED'): + pass + +class IndexCorrupted(InternalError, + code='XX002', name='INDEX_CORRUPTED'): + pass + + +# autogenerated: end +# fmt: on diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py new file mode 100644 index 0000000..584fe47 --- /dev/null +++ b/psycopg/psycopg/generators.py @@ -0,0 +1,320 @@ +""" +Generators implementing communication protocols with the libpq + +Certain operations (connection, querying) are an interleave of libpq calls and +waiting for the socket to be ready. 
This module contains the code to execute +the operations, yielding a polling state whenever there is to wait. The +functions in the `waiting` module are the ones who wait more or less +cooperatively for the socket to be ready and make these generators continue. + +All these generators yield pairs (fileno, `Wait`) whenever an operation would +block. The generator can be restarted sending the appropriate `Ready` state +when the file descriptor is ready. + +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import List, Optional, Union + +from . import pq +from . import errors as e +from .abc import Buffer, PipelineCommand, PQGen, PQGenConn +from .pq.abc import PGconn, PGresult +from .waiting import Wait, Ready +from ._compat import Deque +from ._cmodule import _psycopg +from ._encodings import pgconn_encoding, conninfo_encoding + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +POLL_OK = pq.PollingStatus.OK +POLL_READING = pq.PollingStatus.READING +POLL_WRITING = pq.PollingStatus.WRITING +POLL_FAILED = pq.PollingStatus.FAILED + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +COPY_OUT = pq.ExecStatus.COPY_OUT +COPY_IN = pq.ExecStatus.COPY_IN +COPY_BOTH = pq.ExecStatus.COPY_BOTH +PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + +logger = logging.getLogger(__name__) + + +def _connect(conninfo: str) -> PQGenConn[PGconn]: + """ + Generator to create a database connection without blocking. + + """ + conn = pq.PGconn.connect_start(conninfo.encode()) + while True: + if conn.status == BAD: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection is bad: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + + status = conn.connect_poll() + if status == POLL_OK: + break + elif status == POLL_READING: + yield conn.socket, WAIT_R + elif status == POLL_WRITING: + yield conn.socket, WAIT_W + elif status == POLL_FAILED: + encoding = conninfo_encoding(conninfo) + raise e.OperationalError( + f"connection failed: {pq.error_message(conn, encoding=encoding)}", + pgconn=conn, + ) + else: + raise e.InternalError(f"unexpected poll status: {status}", pgconn=conn) + + conn.nonblocking = 1 + return conn + + +def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator sending a query and returning results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Return the list of results returned by the database (whether success + or error). + """ + yield from _send(pgconn) + rv = yield from _fetch_many(pgconn) + return rv + + +def _send(pgconn: PGconn) -> PQGen[None]: + """ + Generator to send a query to the server without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + After this generator has finished you may want to cycle using `fetch()` + to retrieve the results available. + """ + while True: + f = pgconn.flush() + if f == 0: + break + + ready = yield WAIT_RW + if ready & READY_R: + # This call may read notifies: they will be saved in the + # PGconn buffer and passed to Python later, in `fetch()`. + pgconn.consume_input() + + +def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]: + """ + Generator retrieving results from the database without blocking. 
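+
+    Results are collected until `!_fetch()` returns no result, or earlier if
+    a COPY or PIPELINE_SYNC result is received (see the loop below).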
+ + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return the list of results returned by the database (whether success + or error). + """ + results: List[PGresult] = [] + while True: + res = yield from _fetch(pgconn) + if not res: + break + + results.append(res) + status = res.status + if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + # After entering copy mode the libpq will create a phony result + # for every request so let's break the endless loop. + break + + if status == PIPELINE_SYNC: + # PIPELINE_SYNC is not followed by a NULL, but we return it alone + # similarly to other result sets. + assert len(results) == 1, results + break + + return results + + +def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]: + """ + Generator retrieving a single result from the database without blocking. + + The query must have already been sent to the server, so pgconn.flush() has + already returned 0. + + Return a result from the database (whether success or error). + """ + if pgconn.is_busy(): + yield WAIT_R + while True: + pgconn.consume_input() + if not pgconn.is_busy(): + break + yield WAIT_R + + _consume_notifies(pgconn) + + return pgconn.get_result() + + +def _pipeline_communicate( + pgconn: PGconn, commands: Deque[PipelineCommand] +) -> PQGen[List[List[PGresult]]]: + """Generator to send queries from a connection in pipeline mode while also + receiving results. + + Return a list results, including single PIPELINE_SYNC elements. + """ + results = [] + + while True: + ready = yield WAIT_RW + + if ready & READY_R: + pgconn.consume_input() + _consume_notifies(pgconn) + + res: List[PGresult] = [] + while not pgconn.is_busy(): + r = pgconn.get_result() + if r is None: + if not res: + break + results.append(res) + res = [] + elif r.status == PIPELINE_SYNC: + assert not res + results.append([r]) + else: + res.append(r) + + if ready & READY_W: + pgconn.flush() + if not commands: + break + commands.popleft()() + + return results + + +def _consume_notifies(pgconn: PGconn) -> None: + # Consume notifies + while True: + n = pgconn.notifies() + if not n: + break + if pgconn.notify_handler: + pgconn.notify_handler(n) + + +def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: + yield WAIT_R + pgconn.consume_input() + + ns = [] + while True: + n = pgconn.notifies() + if n: + ns.append(n) + else: + break + + return ns + + +def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]: + while True: + nbytes, data = pgconn.get_copy_data(1) + if nbytes != 0: + break + + # would block + yield WAIT_R + pgconn.consume_input() + + if nbytes > 0: + # some data + return data + + # Retrieve the final result of copy + results = yield from _fetch_many(pgconn) + if len(results) > 1: + # TODO: too brutal? Copy worked. + raise e.ProgrammingError("you cannot mix COPY with other operations") + result = results[0] + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]: + # Retry enqueuing data until successful. + # + # WARNING! This can cause an infinite loop if the buffer is too large. (see + # ticket #255). We avoid it in the Copy object by splitting a large buffer + # into smaller ones. We prefer to do it there instead of here in order to + # do it upstream the queue decoupling the writer task from the producer one. 
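+    # put_copy_data() returns 0 when the data cannot be queued because the
+    # libpq buffers are full (nonblocking connection): wait for the socket
+    # to become writable and retry.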
+ while pgconn.put_copy_data(buffer) == 0: + yield WAIT_W + + +def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: + # Retry enqueuing end copy message until successful + while pgconn.put_copy_end(error) == 0: + yield WAIT_W + + # Repeat until it the message is flushed to the server + while True: + yield WAIT_W + f = pgconn.flush() + if f == 0: + break + + # Retrieve the final result of copy + (result,) = yield from _fetch_many(pgconn) + if result.status != COMMAND_OK: + encoding = pgconn_encoding(pgconn) + raise e.error_from_result(result, encoding=encoding) + + return result + + +# Override functions with fast versions if available +if _psycopg: + connect = _psycopg.connect + execute = _psycopg.execute + send = _psycopg.send + fetch_many = _psycopg.fetch_many + fetch = _psycopg.fetch + pipeline_communicate = _psycopg.pipeline_communicate + +else: + connect = _connect + execute = _execute + send = _send + fetch_many = _fetch_many + fetch = _fetch + pipeline_communicate = _pipeline_communicate diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py new file mode 100644 index 0000000..792a9c8 --- /dev/null +++ b/psycopg/psycopg/postgres.py @@ -0,0 +1,125 @@ +""" +Types configuration specific to PostgreSQL. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry +from .abc import AdaptContext +from ._adapters_map import AdaptersMap + +# Global objects with PostgreSQL builtins and globally registered user types. +types = TypesRegistry() + +# Global adapter maps with PostgreSQL types configuration +adapters = AdaptersMap(types=types) + +# Use tools/update_oids.py to update this data. +for t in [ + TypeInfo('"char"', 18, 1002), + # autogenerated: start + # Generated from PostgreSQL 15.0 + TypeInfo("aclitem", 1033, 1034), + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, regtype="boolean"), + TypeInfo("box", 603, 1020, delimiter=";"), + TypeInfo("bpchar", 1042, 1014, regtype="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("cid", 29, 1012), + TypeInfo("cidr", 650, 651), + TypeInfo("circle", 718, 719), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, regtype="real"), + TypeInfo("float8", 701, 1022, regtype="double precision"), + TypeInfo("gtsvector", 3642, 3644), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, regtype="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007, regtype="integer"), + TypeInfo("int8", 20, 1016, regtype="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("json", 114, 199), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("jsonpath", 4072, 4073), + TypeInfo("line", 628, 629), + TypeInfo("lseg", 601, 1018), + TypeInfo("macaddr", 829, 1040), + TypeInfo("macaddr8", 774, 775), + TypeInfo("money", 790, 791), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("path", 602, 1019), + TypeInfo("pg_lsn", 3220, 3221), + TypeInfo("point", 600, 1017), + TypeInfo("polygon", 604, 1027), + TypeInfo("record", 2249, 2287), + TypeInfo("refcursor", 1790, 2201), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regcollation", 4191, 4192), + TypeInfo("regconfig", 3734, 3735), + TypeInfo("regdictionary", 3769, 3770), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regoper", 2203, 2208), + TypeInfo("regoperator", 2204, 2209), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + 
TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("tid", 27, 1010), + TypeInfo("time", 1083, 1183, regtype="time without time zone"), + TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, regtype="time with time zone"), + TypeInfo("tsquery", 3615, 3645), + TypeInfo("tsvector", 3614, 3643), + TypeInfo("txid_snapshot", 2970, 2949), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, regtype="bit varying"), + TypeInfo("varchar", 1043, 1015, regtype="character varying"), + TypeInfo("xid", 28, 1011), + TypeInfo("xid8", 5069, 271), + TypeInfo("xml", 142, 143), + RangeInfo("daterange", 3912, 3913, subtype_oid=1082), + RangeInfo("int4range", 3904, 3905, subtype_oid=23), + RangeInfo("int8range", 3926, 3927, subtype_oid=20), + RangeInfo("numrange", 3906, 3907, subtype_oid=1700), + RangeInfo("tsrange", 3908, 3909, subtype_oid=1114), + RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184), + MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082), + MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23), + MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20), + MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700), + MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114), + MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184), + # autogenerated: end +]: + types.add(t) + + +# A few oids used a bit everywhere +INVALID_OID = 0 +TEXT_OID = types["text"].oid +TEXT_ARRAY_OID = types["text"].array_oid + + +def register_default_adapters(context: AdaptContext) -> None: + + from .types import array, bool, composite, datetime, enum, json, multirange + from .types import net, none, numeric, range, string, uuid + + array.register_default_adapters(context) + bool.register_default_adapters(context) + composite.register_default_adapters(context) + datetime.register_default_adapters(context) + enum.register_default_adapters(context) + json.register_default_adapters(context) + multirange.register_default_adapters(context) + net.register_default_adapters(context) + none.register_default_adapters(context) + numeric.register_default_adapters(context) + range.register_default_adapters(context) + string.register_default_adapters(context) + uuid.register_default_adapters(context) diff --git a/psycopg/psycopg/pq/__init__.py b/psycopg/psycopg/pq/__init__.py new file mode 100644 index 0000000..d5180b1 --- /dev/null +++ b/psycopg/psycopg/pq/__init__.py @@ -0,0 +1,133 @@ +""" +psycopg libpq wrapper + +This package exposes the libpq functionalities as Python objects and functions. + +The real implementation (the binding to the C library) is +implementation-dependant but all the implementations share the same interface. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import logging +from typing import Callable, List, Type + +from . import abc +from .misc import ConninfoOption, PGnotify, PGresAttDesc +from .misc import error_message +from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace +from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus + +logger = logging.getLogger(__name__) + +__impl__: str +"""The currently loaded implementation of the `!psycopg.pq` package. + +Possible values include ``python``, ``c``, ``binary``. 
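+
+For example, to check which implementation is in use::
+
+    from psycopg import pq
+    print(pq.__impl__)
+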
+""" + +__build_version__: int +"""The libpq version the C package was built with. + +A number in the same format of `~psycopg.ConnectionInfo.server_version` +representing the libpq used to build the speedup module (``c``, ``binary``) if +available. + +Certain features might not be available if the built version is too old. +""" + +version: Callable[[], int] +PGconn: Type[abc.PGconn] +PGresult: Type[abc.PGresult] +Conninfo: Type[abc.Conninfo] +Escaping: Type[abc.Escaping] +PGcancel: Type[abc.PGcancel] + + +def import_from_libpq() -> None: + """ + Import pq objects implementation from the best libpq wrapper available. + + If an implementation is requested try to import only it, otherwise + try to import the best implementation available. + """ + # import these names into the module on success as side effect + global __impl__, version, __build_version__ + global PGconn, PGresult, Conninfo, Escaping, PGcancel + + impl = os.environ.get("PSYCOPG_IMPL", "").lower() + module = None + attempts: List[str] = [] + + def handle_error(name: str, e: Exception) -> None: + if not impl: + msg = f"couldn't import psycopg '{name}' implementation: {e}" + logger.debug(msg) + attempts.append(msg) + else: + msg = f"couldn't import requested psycopg '{name}' implementation: {e}" + raise ImportError(msg) from e + + # The best implementation: fast but requires the system libpq installed + if not impl or impl == "c": + try: + from psycopg_c import pq as module # type: ignore + except Exception as e: + handle_error("c", e) + + # Second best implementation: fast and stand-alone + if not module and (not impl or impl == "binary"): + try: + from psycopg_binary import pq as module # type: ignore + except Exception as e: + handle_error("binary", e) + + # Pure Python implementation, slow and requires the system libpq installed. + if not module and (not impl or impl == "python"): + try: + from . import pq_ctypes as module # type: ignore[no-redef] + except Exception as e: + handle_error("python", e) + + if module: + __impl__ = module.__impl__ + version = module.version + PGconn = module.PGconn + PGresult = module.PGresult + Conninfo = module.Conninfo + Escaping = module.Escaping + PGcancel = module.PGcancel + __build_version__ = module.__build_version__ + elif impl: + raise ImportError(f"requested psycopg implementation '{impl}' unknown") + else: + sattempts = "\n".join(f"- {attempt}" for attempt in attempts) + raise ImportError( + f"""\ +no pq wrapper available. +Attempts made: +{sattempts}""" + ) + + +import_from_libpq() + +__all__ = ( + "ConnStatus", + "PipelineStatus", + "PollingStatus", + "TransactionStatus", + "ExecStatus", + "Ping", + "DiagnosticField", + "Format", + "Trace", + "PGconn", + "PGnotify", + "Conninfo", + "PGresAttDesc", + "error_message", + "ConninfoOption", + "version", +) diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py new file mode 100644 index 0000000..f35d09f --- /dev/null +++ b/psycopg/psycopg/pq/_debug.py @@ -0,0 +1,106 @@ +""" +libpq debugging tools + +These functionalities are exposed here for convenience, but are not part of +the public interface and are subject to change at any moment. 
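+
+The `!PGconnDebug` wrapper defined below logs every attribute access and
+method call performed on the wrapped `!PGconn`.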
+ +Suggested usage:: + + import logging + import psycopg + from psycopg import pq + from psycopg.pq._debug import PGconnDebug + + logging.basicConfig(level=logging.INFO, format="%(message)s") + logger = logging.getLogger("psycopg.debug") + logger.setLevel(logging.INFO) + + assert pq.__impl__ == "python" + pq.PGconn = PGconnDebug + + with psycopg.connect("") as conn: + conn.pgconn.trace(2) + conn.pgconn.set_trace_flags( + pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE) + ... + +""" + +# Copyright (C) 2022 The Psycopg Team + +import inspect +import logging +from typing import Any, Callable, Type, TypeVar, TYPE_CHECKING +from functools import wraps + +from . import PGconn +from .misc import connection_summary + +if TYPE_CHECKING: + from . import abc + +Func = TypeVar("Func", bound=Callable[..., Any]) + +logger = logging.getLogger("psycopg.debug") + + +class PGconnDebug: + """Wrapper for a PQconn logging all its access.""" + + _Self = TypeVar("_Self", bound="PGconnDebug") + _pgconn: "abc.PGconn" + + def __init__(self, pgconn: "abc.PGconn"): + super().__setattr__("_pgconn", pgconn) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def __getattr__(self, attr: str) -> Any: + value = getattr(self._pgconn, attr) + if callable(value): + return debugging(value) + else: + logger.info("PGconn.%s -> %s", attr, value) + return value + + def __setattr__(self, attr: str, value: Any) -> None: + setattr(self._pgconn, attr, value) + logger.info("PGconn.%s <- %s", attr, value) + + @classmethod + def connect(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect)(conninfo)) + + @classmethod + def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self: + return cls(debugging(PGconn.connect_start)(conninfo)) + + @classmethod + def ping(self, conninfo: bytes) -> int: + return debugging(PGconn.ping)(conninfo) + + +def debugging(f: Func) -> Func: + """Wrap a function in order to log its arguments and return value on call.""" + + @wraps(f) + def debugging_(*args: Any, **kwargs: Any) -> Any: + reprs = [] + for arg in args: + reprs.append(f"{arg!r}") + for (k, v) in kwargs.items(): + reprs.append(f"{k}={v!r}") + + logger.info("PGconn.%s(%s)", f.__name__, ", ".join(reprs)) + rv = f(*args, **kwargs) + # Display the return value only if the function is declared to return + # something else than None. + ra = inspect.signature(f).return_annotation + if ra is not None or rv is not None: + logger.info(" <- %r", rv) + return rv + + return debugging_ # type: ignore diff --git a/psycopg/psycopg/pq/_enums.py b/psycopg/psycopg/pq/_enums.py new file mode 100644 index 0000000..e0d4018 --- /dev/null +++ b/psycopg/psycopg/pq/_enums.py @@ -0,0 +1,249 @@ +""" +libpq enum definitions for psycopg +""" + +# Copyright (C) 2020 The Psycopg Team + +from enum import IntEnum, IntFlag, auto + + +class ConnStatus(IntEnum): + """ + Current status of the connection. + """ + + __module__ = "psycopg.pq" + + OK = 0 + """The connection is in a working state.""" + BAD = auto() + """The connection is closed.""" + + STARTED = auto() + MADE = auto() + AWAITING_RESPONSE = auto() + AUTH_OK = auto() + SETENV = auto() + SSL_STARTUP = auto() + NEEDED = auto() + CHECK_WRITABLE = auto() + CONSUME = auto() + GSS_STARTUP = auto() + CHECK_TARGET = auto() + CHECK_STANDBY = auto() + + +class PollingStatus(IntEnum): + """ + The status of the socket during a connection. 
+ + If ``READING`` or ``WRITING`` you may select before polling again. + """ + + __module__ = "psycopg.pq" + + FAILED = 0 + """Connection attempt failed.""" + READING = auto() + """Will have to wait before reading new data.""" + WRITING = auto() + """Will have to wait before writing new data.""" + OK = auto() + """Connection completed.""" + + ACTIVE = auto() + + +class ExecStatus(IntEnum): + """ + The status of a command. + """ + + __module__ = "psycopg.pq" + + EMPTY_QUERY = 0 + """The string sent to the server was empty.""" + + COMMAND_OK = auto() + """Successful completion of a command returning no data.""" + + TUPLES_OK = auto() + """ + Successful completion of a command returning data (such as a SELECT or SHOW). + """ + + COPY_OUT = auto() + """Copy Out (from server) data transfer started.""" + + COPY_IN = auto() + """Copy In (to server) data transfer started.""" + + BAD_RESPONSE = auto() + """The server's response was not understood.""" + + NONFATAL_ERROR = auto() + """A nonfatal error (a notice or warning) occurred.""" + + FATAL_ERROR = auto() + """A fatal error occurred.""" + + COPY_BOTH = auto() + """ + Copy In/Out (to and from server) data transfer started. + + This feature is currently used only for streaming replication, so this + status should not occur in ordinary applications. + """ + + SINGLE_TUPLE = auto() + """ + The PGresult contains a single result tuple from the current command. + + This status occurs only when single-row mode has been selected for the + query. + """ + + PIPELINE_SYNC = auto() + """ + The PGresult represents a synchronization point in pipeline mode, + requested by PQpipelineSync. + + This status occurs only when pipeline mode has been selected. + """ + + PIPELINE_ABORTED = auto() + """ + The PGresult represents a pipeline that has received an error from the server. + + PQgetResult must be called repeatedly, and each time it will return this + status code until the end of the current pipeline, at which point it will + return PGRES_PIPELINE_SYNC and normal processing can resume. + """ + + +class TransactionStatus(IntEnum): + """ + The transaction status of a connection. + """ + + __module__ = "psycopg.pq" + + IDLE = 0 + """Connection ready, no transaction active.""" + + ACTIVE = auto() + """A command is in progress.""" + + INTRANS = auto() + """Connection idle in an open transaction.""" + + INERROR = auto() + """An error happened in the current transaction.""" + + UNKNOWN = auto() + """Unknown connection state, broken connection.""" + + +class Ping(IntEnum): + """Response from a ping attempt.""" + + __module__ = "psycopg.pq" + + OK = 0 + """ + The server is running and appears to be accepting connections. + """ + + REJECT = auto() + """ + The server is running but is in a state that disallows connections. + """ + + NO_RESPONSE = auto() + """ + The server could not be contacted. + """ + + NO_ATTEMPT = auto() + """ + No attempt was made to contact the server. + """ + + +class PipelineStatus(IntEnum): + """Pipeline mode status of the libpq connection.""" + + __module__ = "psycopg.pq" + + OFF = 0 + """ + The libpq connection is *not* in pipeline mode. + """ + ON = auto() + """ + The libpq connection is in pipeline mode. + """ + ABORTED = auto() + """ + The libpq connection is in pipeline mode and an error occurred while + processing the current pipeline. The aborted flag is cleared when + PQgetResult returns a result of type PGRES_PIPELINE_SYNC. + """ + + +class DiagnosticField(IntEnum): + """ + Fields in an error report. 
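The `PollingStatus` values above are meant to drive a non-blocking connection loop. A minimal sketch using the `PGconn` wrapper defined later in this diff; the connection string is a placeholder and error handling is reduced to a bare exception::

    import select
    from psycopg import pq

    pgconn = pq.PGconn.connect_start(b"dbname=test")   # hypothetical conninfo
    while True:
        status = pgconn.connect_poll()
        if status == pq.PollingStatus.OK:
            break
        elif status == pq.PollingStatus.READING:
            select.select([pgconn.socket], [], [])      # wait until readable
        elif status == pq.PollingStatus.WRITING:
            select.select([], [pgconn.socket], [])      # wait until writable
        else:
            raise RuntimeError(pq.error_message(pgconn))

    assert pgconn.status == pq.ConnStatus.OK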
+ """ + + __module__ = "psycopg.pq" + + # from postgres_ext.h + SEVERITY = ord("S") + SEVERITY_NONLOCALIZED = ord("V") + SQLSTATE = ord("C") + MESSAGE_PRIMARY = ord("M") + MESSAGE_DETAIL = ord("D") + MESSAGE_HINT = ord("H") + STATEMENT_POSITION = ord("P") + INTERNAL_POSITION = ord("p") + INTERNAL_QUERY = ord("q") + CONTEXT = ord("W") + SCHEMA_NAME = ord("s") + TABLE_NAME = ord("t") + COLUMN_NAME = ord("c") + DATATYPE_NAME = ord("d") + CONSTRAINT_NAME = ord("n") + SOURCE_FILE = ord("F") + SOURCE_LINE = ord("L") + SOURCE_FUNCTION = ord("R") + + +class Format(IntEnum): + """ + Enum representing the format of a query argument or return value. + + These values are only the ones managed by the libpq. `~psycopg` may also + support automatically-chosen values: see `psycopg.adapt.PyFormat`. + """ + + __module__ = "psycopg.pq" + + TEXT = 0 + """Text parameter.""" + BINARY = 1 + """Binary parameter.""" + + +class Trace(IntFlag): + """ + Enum to control tracing of the client/server communication. + """ + + __module__ = "psycopg.pq" + + SUPPRESS_TIMESTAMPS = 1 + """Do not include timestamps in messages.""" + + REGRESS_MODE = 2 + """Redact some fields, e.g. OIDs, from messages.""" diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py new file mode 100644 index 0000000..9ca1d12 --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.py @@ -0,0 +1,804 @@ +""" +libpq access using ctypes +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import ctypes +import ctypes.util +from ctypes import Structure, CFUNCTYPE, POINTER +from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p +from typing import List, Optional, Tuple + +from .misc import find_libpq_full_path +from ..errors import NotSupportedError + +libname = find_libpq_full_path() +if not libname: + raise ImportError("libpq library not found") + +pq = ctypes.cdll.LoadLibrary(libname) + + +class FILE(Structure): + pass + + +FILE_ptr = POINTER(FILE) + +if sys.platform == "linux": + libcname = ctypes.util.find_library("c") + assert libcname + libc = ctypes.cdll.LoadLibrary(libcname) + + fdopen = libc.fdopen + fdopen.argtypes = (c_int, c_char_p) + fdopen.restype = FILE_ptr + + +# Get the libpq version to define what functions are available. 
+ +PQlibVersion = pq.PQlibVersion +PQlibVersion.argtypes = [] +PQlibVersion.restype = c_int + +libpq_version = PQlibVersion() + + +# libpq data types + + +Oid = c_uint + + +class PGconn_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresult_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PQconninfoOption_struct(Structure): + _fields_ = [ + ("keyword", c_char_p), + ("envvar", c_char_p), + ("compiled", c_char_p), + ("val", c_char_p), + ("label", c_char_p), + ("dispchar", c_char_p), + ("dispsize", c_int), + ] + + +class PGnotify_struct(Structure): + _fields_ = [ + ("relname", c_char_p), + ("be_pid", c_int), + ("extra", c_char_p), + ] + + +class PGcancel_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + +class PGresAttDesc_struct(Structure): + _fields_ = [ + ("name", c_char_p), + ("tableid", Oid), + ("columnid", c_int), + ("format", c_int), + ("typid", Oid), + ("typlen", c_int), + ("atttypmod", c_int), + ] + + +PGconn_ptr = POINTER(PGconn_struct) +PGresult_ptr = POINTER(PGresult_struct) +PQconninfoOption_ptr = POINTER(PQconninfoOption_struct) +PGnotify_ptr = POINTER(PGnotify_struct) +PGcancel_ptr = POINTER(PGcancel_struct) +PGresAttDesc_ptr = POINTER(PGresAttDesc_struct) + + +# Function definitions as explained in PostgreSQL 12 documentation + +# 33.1. Database Connection Control Functions + +# PQconnectdbParams: doesn't seem useful, won't wrap for now + +PQconnectdb = pq.PQconnectdb +PQconnectdb.argtypes = [c_char_p] +PQconnectdb.restype = PGconn_ptr + +# PQsetdbLogin: not useful +# PQsetdb: not useful + +# PQconnectStartParams: not useful + +PQconnectStart = pq.PQconnectStart +PQconnectStart.argtypes = [c_char_p] +PQconnectStart.restype = PGconn_ptr + +PQconnectPoll = pq.PQconnectPoll +PQconnectPoll.argtypes = [PGconn_ptr] +PQconnectPoll.restype = c_int + +PQconndefaults = pq.PQconndefaults +PQconndefaults.argtypes = [] +PQconndefaults.restype = PQconninfoOption_ptr + +PQconninfoFree = pq.PQconninfoFree +PQconninfoFree.argtypes = [PQconninfoOption_ptr] +PQconninfoFree.restype = None + +PQconninfo = pq.PQconninfo +PQconninfo.argtypes = [PGconn_ptr] +PQconninfo.restype = PQconninfoOption_ptr + +PQconninfoParse = pq.PQconninfoParse +PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)] +PQconninfoParse.restype = PQconninfoOption_ptr + +PQfinish = pq.PQfinish +PQfinish.argtypes = [PGconn_ptr] +PQfinish.restype = None + +PQreset = pq.PQreset +PQreset.argtypes = [PGconn_ptr] +PQreset.restype = None + +PQresetStart = pq.PQresetStart +PQresetStart.argtypes = [PGconn_ptr] +PQresetStart.restype = c_int + +PQresetPoll = pq.PQresetPoll +PQresetPoll.argtypes = [PGconn_ptr] +PQresetPoll.restype = c_int + +PQping = pq.PQping +PQping.argtypes = [c_char_p] +PQping.restype = c_int + + +# 33.2. 
Connection Status Functions + +PQdb = pq.PQdb +PQdb.argtypes = [PGconn_ptr] +PQdb.restype = c_char_p + +PQuser = pq.PQuser +PQuser.argtypes = [PGconn_ptr] +PQuser.restype = c_char_p + +PQpass = pq.PQpass +PQpass.argtypes = [PGconn_ptr] +PQpass.restype = c_char_p + +PQhost = pq.PQhost +PQhost.argtypes = [PGconn_ptr] +PQhost.restype = c_char_p + +_PQhostaddr = None + +if libpq_version >= 120000: + _PQhostaddr = pq.PQhostaddr + _PQhostaddr.argtypes = [PGconn_ptr] + _PQhostaddr.restype = c_char_p + + +def PQhostaddr(pgconn: PGconn_struct) -> bytes: + if not _PQhostaddr: + raise NotSupportedError( + "PQhostaddr requires libpq from PostgreSQL 12," + f" {libpq_version} available instead" + ) + + return _PQhostaddr(pgconn) + + +PQport = pq.PQport +PQport.argtypes = [PGconn_ptr] +PQport.restype = c_char_p + +PQtty = pq.PQtty +PQtty.argtypes = [PGconn_ptr] +PQtty.restype = c_char_p + +PQoptions = pq.PQoptions +PQoptions.argtypes = [PGconn_ptr] +PQoptions.restype = c_char_p + +PQstatus = pq.PQstatus +PQstatus.argtypes = [PGconn_ptr] +PQstatus.restype = c_int + +PQtransactionStatus = pq.PQtransactionStatus +PQtransactionStatus.argtypes = [PGconn_ptr] +PQtransactionStatus.restype = c_int + +PQparameterStatus = pq.PQparameterStatus +PQparameterStatus.argtypes = [PGconn_ptr, c_char_p] +PQparameterStatus.restype = c_char_p + +PQprotocolVersion = pq.PQprotocolVersion +PQprotocolVersion.argtypes = [PGconn_ptr] +PQprotocolVersion.restype = c_int + +PQserverVersion = pq.PQserverVersion +PQserverVersion.argtypes = [PGconn_ptr] +PQserverVersion.restype = c_int + +PQerrorMessage = pq.PQerrorMessage +PQerrorMessage.argtypes = [PGconn_ptr] +PQerrorMessage.restype = c_char_p + +PQsocket = pq.PQsocket +PQsocket.argtypes = [PGconn_ptr] +PQsocket.restype = c_int + +PQbackendPID = pq.PQbackendPID +PQbackendPID.argtypes = [PGconn_ptr] +PQbackendPID.restype = c_int + +PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword +PQconnectionNeedsPassword.argtypes = [PGconn_ptr] +PQconnectionNeedsPassword.restype = c_int + +PQconnectionUsedPassword = pq.PQconnectionUsedPassword +PQconnectionUsedPassword.argtypes = [PGconn_ptr] +PQconnectionUsedPassword.restype = c_int + +PQsslInUse = pq.PQsslInUse +PQsslInUse.argtypes = [PGconn_ptr] +PQsslInUse.restype = c_int + +# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl + + +# 33.3. 
Command Execution Functions + +PQexec = pq.PQexec +PQexec.argtypes = [PGconn_ptr, c_char_p] +PQexec.restype = PGresult_ptr + +PQexecParams = pq.PQexecParams +PQexecParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecParams.restype = PGresult_ptr + +PQprepare = pq.PQprepare +PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQprepare.restype = PGresult_ptr + +PQexecPrepared = pq.PQexecPrepared +PQexecPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQexecPrepared.restype = PGresult_ptr + +PQdescribePrepared = pq.PQdescribePrepared +PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQdescribePrepared.restype = PGresult_ptr + +PQdescribePortal = pq.PQdescribePortal +PQdescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQdescribePortal.restype = PGresult_ptr + +PQresultStatus = pq.PQresultStatus +PQresultStatus.argtypes = [PGresult_ptr] +PQresultStatus.restype = c_int + +# PQresStatus: not needed, we have pretty enums + +PQresultErrorMessage = pq.PQresultErrorMessage +PQresultErrorMessage.argtypes = [PGresult_ptr] +PQresultErrorMessage.restype = c_char_p + +# TODO: PQresultVerboseErrorMessage + +PQresultErrorField = pq.PQresultErrorField +PQresultErrorField.argtypes = [PGresult_ptr, c_int] +PQresultErrorField.restype = c_char_p + +PQclear = pq.PQclear +PQclear.argtypes = [PGresult_ptr] +PQclear.restype = None + + +# 33.3.2. Retrieving Query Result Information + +PQntuples = pq.PQntuples +PQntuples.argtypes = [PGresult_ptr] +PQntuples.restype = c_int + +PQnfields = pq.PQnfields +PQnfields.argtypes = [PGresult_ptr] +PQnfields.restype = c_int + +PQfname = pq.PQfname +PQfname.argtypes = [PGresult_ptr, c_int] +PQfname.restype = c_char_p + +# PQfnumber: useless and hard to use + +PQftable = pq.PQftable +PQftable.argtypes = [PGresult_ptr, c_int] +PQftable.restype = Oid + +PQftablecol = pq.PQftablecol +PQftablecol.argtypes = [PGresult_ptr, c_int] +PQftablecol.restype = c_int + +PQfformat = pq.PQfformat +PQfformat.argtypes = [PGresult_ptr, c_int] +PQfformat.restype = c_int + +PQftype = pq.PQftype +PQftype.argtypes = [PGresult_ptr, c_int] +PQftype.restype = Oid + +PQfmod = pq.PQfmod +PQfmod.argtypes = [PGresult_ptr, c_int] +PQfmod.restype = c_int + +PQfsize = pq.PQfsize +PQfsize.argtypes = [PGresult_ptr, c_int] +PQfsize.restype = c_int + +PQbinaryTuples = pq.PQbinaryTuples +PQbinaryTuples.argtypes = [PGresult_ptr] +PQbinaryTuples.restype = c_int + +PQgetvalue = pq.PQgetvalue +PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int] +PQgetvalue.restype = POINTER(c_char) # not a null-terminated string + +PQgetisnull = pq.PQgetisnull +PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int] +PQgetisnull.restype = c_int + +PQgetlength = pq.PQgetlength +PQgetlength.argtypes = [PGresult_ptr, c_int, c_int] +PQgetlength.restype = c_int + +PQnparams = pq.PQnparams +PQnparams.argtypes = [PGresult_ptr] +PQnparams.restype = c_int + +PQparamtype = pq.PQparamtype +PQparamtype.argtypes = [PGresult_ptr, c_int] +PQparamtype.restype = Oid + +# PQprint: pretty useless + +# 33.3.3. Retrieving Other Result Information + +PQcmdStatus = pq.PQcmdStatus +PQcmdStatus.argtypes = [PGresult_ptr] +PQcmdStatus.restype = c_char_p + +PQcmdTuples = pq.PQcmdTuples +PQcmdTuples.argtypes = [PGresult_ptr] +PQcmdTuples.restype = c_char_p + +PQoidValue = pq.PQoidValue +PQoidValue.argtypes = [PGresult_ptr] +PQoidValue.restype = Oid + + +# 33.3.4. 
Escaping Strings for Inclusion in SQL Commands + +PQescapeLiteral = pq.PQescapeLiteral +PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeLiteral.restype = POINTER(c_char) + +PQescapeIdentifier = pq.PQescapeIdentifier +PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeIdentifier.restype = POINTER(c_char) + +PQescapeStringConn = pq.PQescapeStringConn +# TODO: raises "wrong type" error +# PQescapeStringConn.argtypes = [ +# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int) +# ] +PQescapeStringConn.restype = c_size_t + +PQescapeString = pq.PQescapeString +# TODO: raises "wrong type" error +# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t] +PQescapeString.restype = c_size_t + +PQescapeByteaConn = pq.PQescapeByteaConn +PQescapeByteaConn.argtypes = [ + PGconn_ptr, + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeByteaConn.restype = POINTER(c_ubyte) + +PQescapeBytea = pq.PQescapeBytea +PQescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + c_size_t, + POINTER(c_size_t), +] +PQescapeBytea.restype = POINTER(c_ubyte) + + +PQunescapeBytea = pq.PQunescapeBytea +PQunescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + POINTER(c_size_t), +] +PQunescapeBytea.restype = POINTER(c_ubyte) + + +# 33.4. Asynchronous Command Processing + +PQsendQuery = pq.PQsendQuery +PQsendQuery.argtypes = [PGconn_ptr, c_char_p] +PQsendQuery.restype = c_int + +PQsendQueryParams = pq.PQsendQueryParams +PQsendQueryParams.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(Oid), + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryParams.restype = c_int + +PQsendPrepare = pq.PQsendPrepare +PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)] +PQsendPrepare.restype = c_int + +PQsendQueryPrepared = pq.PQsendQueryPrepared +PQsendQueryPrepared.argtypes = [ + PGconn_ptr, + c_char_p, + c_int, + POINTER(c_char_p), + POINTER(c_int), + POINTER(c_int), + c_int, +] +PQsendQueryPrepared.restype = c_int + +PQsendDescribePrepared = pq.PQsendDescribePrepared +PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePrepared.restype = c_int + +PQsendDescribePortal = pq.PQsendDescribePortal +PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p] +PQsendDescribePortal.restype = c_int + +PQgetResult = pq.PQgetResult +PQgetResult.argtypes = [PGconn_ptr] +PQgetResult.restype = PGresult_ptr + +PQconsumeInput = pq.PQconsumeInput +PQconsumeInput.argtypes = [PGconn_ptr] +PQconsumeInput.restype = c_int + +PQisBusy = pq.PQisBusy +PQisBusy.argtypes = [PGconn_ptr] +PQisBusy.restype = c_int + +PQsetnonblocking = pq.PQsetnonblocking +PQsetnonblocking.argtypes = [PGconn_ptr, c_int] +PQsetnonblocking.restype = c_int + +PQisnonblocking = pq.PQisnonblocking +PQisnonblocking.argtypes = [PGconn_ptr] +PQisnonblocking.restype = c_int + +PQflush = pq.PQflush +PQflush.argtypes = [PGconn_ptr] +PQflush.restype = c_int + + +# 33.5. Retrieving Query Results Row-by-Row +PQsetSingleRowMode = pq.PQsetSingleRowMode +PQsetSingleRowMode.argtypes = [PGconn_ptr] +PQsetSingleRowMode.restype = c_int + + +# 33.6. 
Canceling Queries in Progress + +PQgetCancel = pq.PQgetCancel +PQgetCancel.argtypes = [PGconn_ptr] +PQgetCancel.restype = PGcancel_ptr + +PQfreeCancel = pq.PQfreeCancel +PQfreeCancel.argtypes = [PGcancel_ptr] +PQfreeCancel.restype = None + +PQcancel = pq.PQcancel +# TODO: raises "wrong type" error +# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int] +PQcancel.restype = c_int + + +# 33.8. Asynchronous Notification + +PQnotifies = pq.PQnotifies +PQnotifies.argtypes = [PGconn_ptr] +PQnotifies.restype = PGnotify_ptr + + +# 33.9. Functions Associated with the COPY Command + +PQputCopyData = pq.PQputCopyData +PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int] +PQputCopyData.restype = c_int + +PQputCopyEnd = pq.PQputCopyEnd +PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p] +PQputCopyEnd.restype = c_int + +PQgetCopyData = pq.PQgetCopyData +PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int] +PQgetCopyData.restype = c_int + + +# 33.10. Control Functions + +PQtrace = pq.PQtrace +PQtrace.argtypes = [PGconn_ptr, FILE_ptr] +PQtrace.restype = None + +_PQsetTraceFlags = None + +if libpq_version >= 140000: + _PQsetTraceFlags = pq.PQsetTraceFlags + _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int] + _PQsetTraceFlags.restype = None + + +def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None: + if not _PQsetTraceFlags: + raise NotSupportedError( + "PQsetTraceFlags requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + + _PQsetTraceFlags(pgconn, flags) + + +PQuntrace = pq.PQuntrace +PQuntrace.argtypes = [PGconn_ptr] +PQuntrace.restype = None + +# 33.11. Miscellaneous Functions + +PQfreemem = pq.PQfreemem +PQfreemem.argtypes = [c_void_p] +PQfreemem.restype = None + +if libpq_version >= 100000: + _PQencryptPasswordConn = pq.PQencryptPasswordConn + _PQencryptPasswordConn.argtypes = [ + PGconn_ptr, + c_char_p, + c_char_p, + c_char_p, + ] + _PQencryptPasswordConn.restype = POINTER(c_char) + + +def PQencryptPasswordConn( + pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes +) -> Optional[bytes]: + if not _PQencryptPasswordConn: + raise NotSupportedError( + "PQencryptPasswordConn requires libpq from PostgreSQL 10," + f" {libpq_version} available instead" + ) + + return _PQencryptPasswordConn(pgconn, passwd, user, algorithm) + + +PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult +PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int] +PQmakeEmptyPGresult.restype = PGresult_ptr + +PQsetResultAttrs = pq.PQsetResultAttrs +PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr] +PQsetResultAttrs.restype = c_int + + +# 33.12. 
Notice Processing + +PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr) + +PQsetNoticeReceiver = pq.PQsetNoticeReceiver +PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p] +PQsetNoticeReceiver.restype = PQnoticeReceiver + +# 34.5 Pipeline Mode + +_PQpipelineStatus = None +_PQenterPipelineMode = None +_PQexitPipelineMode = None +_PQpipelineSync = None +_PQsendFlushRequest = None + +if libpq_version >= 140000: + _PQpipelineStatus = pq.PQpipelineStatus + _PQpipelineStatus.argtypes = [PGconn_ptr] + _PQpipelineStatus.restype = c_int + + _PQenterPipelineMode = pq.PQenterPipelineMode + _PQenterPipelineMode.argtypes = [PGconn_ptr] + _PQenterPipelineMode.restype = c_int + + _PQexitPipelineMode = pq.PQexitPipelineMode + _PQexitPipelineMode.argtypes = [PGconn_ptr] + _PQexitPipelineMode.restype = c_int + + _PQpipelineSync = pq.PQpipelineSync + _PQpipelineSync.argtypes = [PGconn_ptr] + _PQpipelineSync.restype = c_int + + _PQsendFlushRequest = pq.PQsendFlushRequest + _PQsendFlushRequest.argtypes = [PGconn_ptr] + _PQsendFlushRequest.restype = c_int + + +def PQpipelineStatus(pgconn: PGconn_struct) -> int: + if not _PQpipelineStatus: + raise NotSupportedError( + "PQpipelineStatus requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineStatus(pgconn) + + +def PQenterPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQenterPipelineMode: + raise NotSupportedError( + "PQenterPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQenterPipelineMode(pgconn) + + +def PQexitPipelineMode(pgconn: PGconn_struct) -> int: + if not _PQexitPipelineMode: + raise NotSupportedError( + "PQexitPipelineMode requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQexitPipelineMode(pgconn) + + +def PQpipelineSync(pgconn: PGconn_struct) -> int: + if not _PQpipelineSync: + raise NotSupportedError( + "PQpipelineSync requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQpipelineSync(pgconn) + + +def PQsendFlushRequest(pgconn: PGconn_struct) -> int: + if not _PQsendFlushRequest: + raise NotSupportedError( + "PQsendFlushRequest requires libpq from PostgreSQL 14," + f" {libpq_version} available instead" + ) + return _PQsendFlushRequest(pgconn) + + +# 33.18. 
SSL Support + +PQinitOpenSSL = pq.PQinitOpenSSL +PQinitOpenSSL.argtypes = [c_int, c_int] +PQinitOpenSSL.restype = None + + +def generate_stub() -> None: + import re + from ctypes import _CFuncPtr # type: ignore + + def type2str(fname, narg, t): + if t is None: + return "None" + elif t is c_void_p: + return "Any" + elif t is c_int or t is c_uint or t is c_size_t: + return "int" + elif t is c_char_p or t.__name__ == "LP_c_char": + if narg is not None: + return "bytes" + else: + return "Optional[bytes]" + + elif t.__name__ in ( + "LP_PGconn_struct", + "LP_PGresult_struct", + "LP_PGcancel_struct", + ): + if narg is not None: + return f"Optional[{t.__name__[3:]}]" + else: + return t.__name__[3:] + + elif t.__name__ in ("LP_PQconninfoOption_struct",): + return f"Sequence[{t.__name__[3:]}]" + + elif t.__name__ in ( + "LP_c_ubyte", + "LP_c_char_p", + "LP_c_int", + "LP_c_uint", + "LP_c_ulong", + "LP_FILE", + ): + return f"_Pointer[{t.__name__[3:]}]" + + else: + assert False, f"can't deal with {t} in {fname}" + + fn = __file__ + "i" + with open(fn) as f: + lines = f.read().splitlines() + + istart, iend = ( + i + for i, line in enumerate(lines) + if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line) + ) + + known = { + line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ") + } + + signatures = [] + + for name, obj in globals().items(): + if name in known: + continue + if not isinstance(obj, _CFuncPtr): + continue + + params = [] + for i, t in enumerate(obj.argtypes): + params.append(f"arg{i + 1}: {type2str(name, i, t)}") + + resname = type2str(name, None, obj.restype) + + signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...") + + lines[istart + 1 : iend] = signatures + + with open(fn, "w") as f: + f.write("\n".join(lines)) + f.write("\n") + + +if __name__ == "__main__": + generate_stub() diff --git a/psycopg/psycopg/pq/_pq_ctypes.pyi b/psycopg/psycopg/pq/_pq_ctypes.pyi new file mode 100644 index 0000000..5d2ee3f --- /dev/null +++ b/psycopg/psycopg/pq/_pq_ctypes.pyi @@ -0,0 +1,216 @@ +""" +types stub for ctypes functions +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Optional, Sequence +from ctypes import Array, pointer, _Pointer +from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong + +class FILE: ... + +def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var] + +Oid = c_uint + +class PGconn_struct: ... +class PGresult_struct: ... +class PGcancel_struct: ... + +class PQconninfoOption_struct: + keyword: bytes + envvar: bytes + compiled: bytes + val: bytes + label: bytes + dispchar: bytes + dispsize: int + +class PGnotify_struct: + be_pid: int + relname: bytes + extra: bytes + +class PGresAttDesc_struct: + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + +def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ... +def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQexecPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> PGresult_struct: ... +def PQprepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> PGresult_struct: ... +def PQgetvalue( + arg1: Optional[PGresult_struct], arg2: int, arg3: int +) -> _Pointer[c_char]: ... 
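The `generate_stub()` helper defined above rewrites the ``autogenerated`` section of the companion ``_pq_ctypes.pyi`` (which follows) by introspecting the ctypes prototypes. A sketch of invoking it from Python, equivalent to running the module as a script::

    # rewrites the block between "# autogenerated: start" and
    # "# autogenerated: end" in the .pyi file next to the module source
    from psycopg.pq import _pq_ctypes
    _pq_ctypes.generate_stub()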
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQescapeStringConn( + arg1: Optional[PGconn_struct], + arg2: c_char_p, + arg3: bytes, + arg4: int, + arg5: _Pointer[c_int], +) -> int: ... +def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ... +def PQsendPrepare( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: int, + arg5: Optional[Array[c_uint]], +) -> int: ... +def PQsendQueryPrepared( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: int, + arg4: Optional[Array[c_char_p]], + arg5: Optional[Array[c_int]], + arg6: Optional[Array[c_int]], + arg7: int, +) -> int: ... +def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ... +def PQsetNoticeReceiver( + arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any +) -> Callable[[Any], PGresult_struct]: ... + +# TODO: Ignoring type as getting an error on mypy/ctypes: +# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be +# a subtype of "ctypes._CData" +def PQnotifies( + arg1: Optional[PGconn_struct], +) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore +def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ... + +# Arg 2 is a _Pointer, reported as _CArgObject by mypy +def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ... +def PQsetResultAttrs( + arg1: Optional[PGresult_struct], + arg2: int, + arg3: Array[PGresAttDesc_struct], # type: ignore +) -> int: ... +def PQtrace( + arg1: Optional[PGconn_struct], + arg2: _Pointer[FILE], # type: ignore[type-var] +) -> None: ... +def PQencryptPasswordConn( + arg1: Optional[PGconn_struct], + arg2: bytes, + arg3: bytes, + arg4: Optional[bytes], +) -> bytes: ... +def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ... +def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ... +def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ... +def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ... + +# fmt: off +# autogenerated: start +def PQlibVersion() -> int: ... +def PQconnectdb(arg1: bytes) -> PGconn_struct: ... +def PQconnectStart(arg1: bytes) -> PGconn_struct: ... +def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ... +def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ... +def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ... +def PQfinish(arg1: Optional[PGconn_struct]) -> None: ... +def PQreset(arg1: Optional[PGconn_struct]) -> None: ... +def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ... +def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ... +def PQping(arg1: bytes) -> int: ... +def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ... +def PQstatus(arg1: Optional[PGconn_struct]) -> int: ... 
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ... +def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ... +def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ... +def PQsocket(arg1: Optional[PGconn_struct]) -> int: ... +def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ... +def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ... +def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ... +def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ... +def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ... +def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQclear(arg1: Optional[PGresult_struct]) -> None: ... +def PQntuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQnfields(arg1: Optional[PGresult_struct]) -> int: ... +def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ... +def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ... +def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... +def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... +def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... +def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... +def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... +def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ... +def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ... +def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... +def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ... +def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ... 
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ... +def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ... +def PQflush(arg1: Optional[PGconn_struct]) -> int: ... +def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ... +def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ... +def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ... +def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ... +def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ... +def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ... +def PQfreemem(arg1: Any) -> None: ... +def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ... +def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ... +def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ... +def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ... +def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ... +def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ... +def PQinitOpenSSL(arg1: int, arg2: int) -> None: ... +# autogenerated: end +# fmt: on + +# vim: set syntax=python: diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py new file mode 100644 index 0000000..9c45f64 --- /dev/null +++ b/psycopg/psycopg/pq/abc.py @@ -0,0 +1,385 @@ +""" +Protocol objects to represent objects exposed by different pq implementations. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from ._enums import Format, Trace +from .._compat import Protocol + +if TYPE_CHECKING: + from .misc import PGnotify, ConninfoOption, PGresAttDesc + +# An object implementing the buffer protocol (ish) +Buffer: TypeAlias = Union[bytes, bytearray, memoryview] + + +class PGconn(Protocol): + + notice_handler: Optional[Callable[["PGresult"], None]] + notify_handler: Optional[Callable[["PGnotify"], None]] + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + ... + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + ... + + def connect_poll(self) -> int: + ... + + def finish(self) -> None: + ... + + @property + def info(self) -> List["ConninfoOption"]: + ... + + def reset(self) -> None: + ... + + def reset_start(self) -> None: + ... + + def reset_poll(self) -> int: + ... + + @classmethod + def ping(self, conninfo: bytes) -> int: + ... + + @property + def db(self) -> bytes: + ... + + @property + def user(self) -> bytes: + ... + + @property + def password(self) -> bytes: + ... + + @property + def host(self) -> bytes: + ... + + @property + def hostaddr(self) -> bytes: + ... + + @property + def port(self) -> bytes: + ... + + @property + def tty(self) -> bytes: + ... + + @property + def options(self) -> bytes: + ... + + @property + def status(self) -> int: + ... + + @property + def transaction_status(self) -> int: + ... + + def parameter_status(self, name: bytes) -> Optional[bytes]: + ... + + @property + def error_message(self) -> bytes: + ... + + @property + def server_version(self) -> int: + ... + + @property + def socket(self) -> int: + ... + + @property + def backend_pid(self) -> int: + ... + + @property + def needs_password(self) -> bool: + ... 
+ + @property + def used_password(self) -> bool: + ... + + @property + def ssl_in_use(self) -> bool: + ... + + def exec_(self, command: bytes) -> "PGresult": + ... + + def send_query(self, command: bytes) -> None: + ... + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + ... + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + ... + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional[Buffer]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + ... + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + ... + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Buffer]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + ... + + def describe_prepared(self, name: bytes) -> "PGresult": + ... + + def send_describe_prepared(self, name: bytes) -> None: + ... + + def describe_portal(self, name: bytes) -> "PGresult": + ... + + def send_describe_portal(self, name: bytes) -> None: + ... + + def get_result(self) -> Optional["PGresult"]: + ... + + def consume_input(self) -> None: + ... + + def is_busy(self) -> int: + ... + + @property + def nonblocking(self) -> int: + ... + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + ... + + def flush(self) -> int: + ... + + def set_single_row_mode(self) -> None: + ... + + def get_cancel(self) -> "PGcancel": + ... + + def notifies(self) -> Optional["PGnotify"]: + ... + + def put_copy_data(self, buffer: Buffer) -> int: + ... + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + ... + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + ... + + def trace(self, fileno: int) -> None: + ... + + def set_trace_flags(self, flags: Trace) -> None: + ... + + def untrace(self) -> None: + ... + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + ... + + def make_empty_result(self, exec_status: int) -> "PGresult": + ... + + @property + def pipeline_status(self) -> int: + ... + + def enter_pipeline_mode(self) -> None: + ... + + def exit_pipeline_mode(self) -> None: + ... + + def pipeline_sync(self) -> None: + ... + + def send_flush_request(self) -> None: + ... + + +class PGresult(Protocol): + def clear(self) -> None: + ... + + @property + def status(self) -> int: + ... + + @property + def error_message(self) -> bytes: + ... + + def error_field(self, fieldcode: int) -> Optional[bytes]: + ... + + @property + def ntuples(self) -> int: + ... + + @property + def nfields(self) -> int: + ... + + def fname(self, column_number: int) -> Optional[bytes]: + ... + + def ftable(self, column_number: int) -> int: + ... + + def ftablecol(self, column_number: int) -> int: + ... + + def fformat(self, column_number: int) -> int: + ... + + def ftype(self, column_number: int) -> int: + ... + + def fmod(self, column_number: int) -> int: + ... 
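Because every implementation exposes objects satisfying these Protocol classes, helper code can be written and type-checked against `!abc.PGconn` and `!abc.PGresult` without knowing which wrapper is loaded. A minimal sketch (the helper name is illustrative)::

    from typing import List, Optional
    from psycopg.pq import abc

    def first_column(res: abc.PGresult) -> List[Optional[bytes]]:
        # collect the first column of a result; None marks a SQL NULL
        return [res.get_value(i, 0) for i in range(res.ntuples)]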
+ + def fsize(self, column_number: int) -> int: + ... + + @property + def binary_tuples(self) -> int: + ... + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + ... + + @property + def nparams(self) -> int: + ... + + def param_type(self, param_number: int) -> int: + ... + + @property + def command_status(self) -> Optional[bytes]: + ... + + @property + def command_tuples(self) -> Optional[int]: + ... + + @property + def oid_value(self) -> int: + ... + + def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: + ... + + +class PGcancel(Protocol): + def free(self) -> None: + ... + + def cancel(self) -> None: + ... + + +class Conninfo(Protocol): + @classmethod + def get_defaults(cls) -> List["ConninfoOption"]: + ... + + @classmethod + def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: + ... + + @classmethod + def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: + ... + + +class Escaping(Protocol): + def __init__(self, conn: Optional[PGconn] = None): + ... + + def escape_literal(self, data: Buffer) -> bytes: + ... + + def escape_identifier(self, data: Buffer) -> bytes: + ... + + def escape_string(self, data: Buffer) -> bytes: + ... + + def escape_bytea(self, data: Buffer) -> bytes: + ... + + def unescape_bytea(self, data: Buffer) -> bytes: + ... diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py new file mode 100644 index 0000000..3a43133 --- /dev/null +++ b/psycopg/psycopg/pq/misc.py @@ -0,0 +1,146 @@ +""" +Various functionalities to make easier to work with the libpq. +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +import sys +import logging +import ctypes.util +from typing import cast, NamedTuple, Optional, Union + +from .abc import PGconn, PGresult +from ._enums import ConnStatus, TransactionStatus, PipelineStatus +from .._compat import cache +from .._encodings import pgconn_encoding + +logger = logging.getLogger("psycopg.pq") + +OK = ConnStatus.OK + + +class PGnotify(NamedTuple): + relname: bytes + be_pid: int + extra: bytes + + +class ConninfoOption(NamedTuple): + keyword: bytes + envvar: Optional[bytes] + compiled: Optional[bytes] + val: Optional[bytes] + label: bytes + dispchar: bytes + dispsize: int + + +class PGresAttDesc(NamedTuple): + name: bytes + tableid: int + columnid: int + format: int + typid: int + typlen: int + atttypmod: int + + +@cache +def find_libpq_full_path() -> Optional[str]: + if sys.platform == "win32": + libname = ctypes.util.find_library("libpq.dll") + + elif sys.platform == "darwin": + libname = ctypes.util.find_library("libpq.dylib") + # (hopefully) temporary hack: libpq not in a standard place + # https://github.com/orgs/Homebrew/discussions/3595 + # If pg_config is available and agrees, let's use its indications. + if not libname: + try: + import subprocess as sp + + libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode() + libname = os.path.join(libdir, "libpq.dylib") + if not os.path.exists(libname): + libname = None + except Exception as ex: + logger.debug("couldn't use pg_config to find libpq: %s", ex) + + else: + libname = ctypes.util.find_library("pq") + + return libname + + +def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str: + """ + Return an error message from a `PGconn` or `PGresult`. + + The return value is a `!str` (unlike pq data which is usually `!bytes`): + use the connection encoding if available, otherwise the `!encoding` + parameter as a fallback for decoding. 
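A sketch of a typical call site for `error_message()`, assuming ``pgconn`` is an already opened `!PGconn`; the error text shown in the comment is only indicative::

    from psycopg import pq

    res = pgconn.exec_(b"select 1/0")
    if res.status == pq.ExecStatus.FATAL_ERROR:
        print(pq.error_message(res))   # e.g. "division by zero"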
Don't raise exceptions on decoding + errors. + + """ + bmsg: bytes + + if hasattr(obj, "error_field"): + # obj is a PGresult + obj = cast(PGresult, obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + elif hasattr(obj, "error_message"): + # obj is a PGconn + if obj.status == OK: + encoding = pgconn_encoding(obj) + bmsg = obj.error_message + + # strip severity and whitespaces + if bmsg: + bmsg = bmsg.split(b":", 1)[-1].strip() + + else: + raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}") + + if bmsg: + msg = bmsg.decode(encoding, "replace") + else: + msg = "no details available" + + return msg + + +def connection_summary(pgconn: PGconn) -> str: + """ + Return summary information on a connection. + + Useful for __repr__ + """ + parts = [] + if pgconn.status == OK: + # Put together the [STATUS] + status = TransactionStatus(pgconn.transaction_status).name + if pgconn.pipeline_status: + status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}" + + # Put together the (CONNECTION) + if not pgconn.host.startswith(b"/"): + parts.append(("host", pgconn.host.decode())) + if pgconn.port != b"5432": + parts.append(("port", pgconn.port.decode())) + if pgconn.user != pgconn.db: + parts.append(("user", pgconn.user.decode())) + parts.append(("database", pgconn.db.decode())) + + else: + status = ConnStatus(pgconn.status).name + + sparts = " ".join("%s=%s" % part for part in parts) + if sparts: + sparts = f" ({sparts})" + return f"[{status}]{sparts}" diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py new file mode 100644 index 0000000..8b87c19 --- /dev/null +++ b/psycopg/psycopg/pq/pq_ctypes.py @@ -0,0 +1,1086 @@ +""" +libpq Python wrapper using ctypes bindings. + +Clients shouldn't use this module directly, unless for testing: they should use +the `pq` module instead, which is in charge of choosing the best +implementation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import sys +import logging +from os import getpid +from weakref import ref + +from ctypes import Array, POINTER, cast, string_at, create_string_buffer, byref +from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object +from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import cast as t_cast, TYPE_CHECKING + +from .. import errors as e +from . import _pq_ctypes as impl +from .misc import PGnotify, ConninfoOption, PGresAttDesc +from .misc import error_message, connection_summary +from ._enums import Format, ExecStatus, Trace + +# Imported locally to call them from __del__ methods +from ._pq_ctypes import PQclear, PQfinish, PQfreeCancel, PQstatus + +if TYPE_CHECKING: + from . import abc + +__impl__ = "python" + +logger = logging.getLogger("psycopg") + + +def version() -> int: + """Return the version number of the libpq currently loaded. + + The number is in the same format of `~psycopg.ConnectionInfo.server_version`. + + Certain features might not be available if the libpq library used is too old. 
+ """ + return impl.PQlibVersion() + + +@impl.PQnoticeReceiver # type: ignore +def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None: + pgconn = cast(arg, POINTER(py_object)).contents.value() + if not (pgconn and pgconn.notice_handler): + return + + res = PGresult(result_ptr) + try: + pgconn.notice_handler(res) + except Exception as exc: + logger.exception("error in notice receiver: %s", exc) + finally: + res._pgresult_ptr = None # avoid destroying the pgresult_ptr + + +class PGconn: + """ + Python representation of a libpq connection. + """ + + __slots__ = ( + "_pgconn_ptr", + "notice_handler", + "notify_handler", + "_self_ptr", + "_procpid", + "__weakref__", + ) + + def __init__(self, pgconn_ptr: impl.PGconn_struct): + self._pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr + self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None + self.notify_handler: Optional[Callable[[PGnotify], None]] = None + + # Keep alive for the lifetime of PGconn + self._self_ptr = py_object(ref(self)) + impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr)) + + self._procpid = getpid() + + def __del__(self) -> None: + # Close the connection only if it was created in this process, + # not if this object is being GC'd after fork. + if getpid() == self._procpid: + self.finish() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self) + return f"<{cls} {info} at 0x{id(self):x}>" + + @classmethod + def connect(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectdb(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + @classmethod + def connect_start(cls, conninfo: bytes) -> "PGconn": + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + pgconn_ptr = impl.PQconnectStart(conninfo) + if not pgconn_ptr: + raise MemoryError("couldn't allocate PGconn") + return cls(pgconn_ptr) + + def connect_poll(self) -> int: + return self._call_int(impl.PQconnectPoll) + + def finish(self) -> None: + self._pgconn_ptr, p = None, self._pgconn_ptr + if p: + PQfinish(p) + + @property + def pgconn_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGconn` structure, as integer. + + `!None` if the connection is closed. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. 
+ """ + if self._pgconn_ptr is None: + return None + + return addressof(self._pgconn_ptr.contents) # type: ignore[attr-defined] + + @property + def info(self) -> List["ConninfoOption"]: + self._ensure_pgconn() + opts = impl.PQconninfo(self._pgconn_ptr) + if not opts: + raise MemoryError("couldn't allocate connection info") + try: + return Conninfo._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + def reset(self) -> None: + self._ensure_pgconn() + impl.PQreset(self._pgconn_ptr) + + def reset_start(self) -> None: + if not impl.PQresetStart(self._pgconn_ptr): + raise e.OperationalError("couldn't reset connection") + + def reset_poll(self) -> int: + return self._call_int(impl.PQresetPoll) + + @classmethod + def ping(self, conninfo: bytes) -> int: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + return impl.PQping(conninfo) + + @property + def db(self) -> bytes: + return self._call_bytes(impl.PQdb) + + @property + def user(self) -> bytes: + return self._call_bytes(impl.PQuser) + + @property + def password(self) -> bytes: + return self._call_bytes(impl.PQpass) + + @property + def host(self) -> bytes: + return self._call_bytes(impl.PQhost) + + @property + def hostaddr(self) -> bytes: + return self._call_bytes(impl.PQhostaddr) + + @property + def port(self) -> bytes: + return self._call_bytes(impl.PQport) + + @property + def tty(self) -> bytes: + return self._call_bytes(impl.PQtty) + + @property + def options(self) -> bytes: + return self._call_bytes(impl.PQoptions) + + @property + def status(self) -> int: + return PQstatus(self._pgconn_ptr) + + @property + def transaction_status(self) -> int: + return impl.PQtransactionStatus(self._pgconn_ptr) + + def parameter_status(self, name: bytes) -> Optional[bytes]: + self._ensure_pgconn() + return impl.PQparameterStatus(self._pgconn_ptr, name) + + @property + def error_message(self) -> bytes: + return impl.PQerrorMessage(self._pgconn_ptr) + + @property + def protocol_version(self) -> int: + return self._call_int(impl.PQprotocolVersion) + + @property + def server_version(self) -> int: + return self._call_int(impl.PQserverVersion) + + @property + def socket(self) -> int: + rv = self._call_int(impl.PQsocket) + if rv == -1: + raise e.OperationalError("the connection is lost") + return rv + + @property + def backend_pid(self) -> int: + return self._call_int(impl.PQbackendPID) + + @property + def needs_password(self) -> bool: + """True if the connection authentication method required a password, + but none was available. + + See :pq:`PQconnectionNeedsPassword` for details. + """ + return bool(impl.PQconnectionNeedsPassword(self._pgconn_ptr)) + + @property + def used_password(self) -> bool: + """True if the connection authentication method used a password. + + See :pq:`PQconnectionUsedPassword` for details. 
+ """ + return bool(impl.PQconnectionUsedPassword(self._pgconn_ptr)) + + @property + def ssl_in_use(self) -> bool: + return self._call_bool(impl.PQsslInUse) + + def exec_(self, command: bytes) -> "PGresult": + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + rv = impl.PQexec(self._pgconn_ptr, command) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query(self, command: bytes) -> None: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + self._ensure_pgconn() + if not impl.PQsendQuery(self._pgconn_ptr, command): + raise e.OperationalError(f"sending query failed: {error_message(self)}") + + def exec_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> "PGresult": + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + rv = impl.PQexecParams(*args) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_query_params( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + args = self._query_params_args( + command, param_values, param_types, param_formats, result_format + ) + self._ensure_pgconn() + if not impl.PQsendQueryParams(*args): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> None: + atypes: Optional[Array[impl.Oid]] + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes): + raise e.OperationalError( + f"sending query and params failed: {error_message(self)}" + ) + + def send_query_prepared( + self, + name: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> None: + # repurpose this function with a cheeky replacement of query with name, + # drop the param_types from the result + args = self._query_params_args( + name, param_values, None, param_formats, result_format + ) + args = args[:3] + args[4:] + + self._ensure_pgconn() + if not impl.PQsendQueryPrepared(*args): + raise e.OperationalError( + f"sending prepared query failed: {error_message(self)}" + ) + + def _query_params_args( + self, + command: bytes, + param_values: Optional[Sequence[Optional["abc.Buffer"]]], + param_types: Optional[Sequence[int]] = None, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, + ) -> Any: + if not isinstance(command, bytes): + raise TypeError(f"bytes expected, got {type(command)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + 
alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + atypes: Optional[Array[impl.Oid]] + if not param_types: + atypes = None + else: + if len(param_types) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_types)) + ) + atypes = (impl.Oid * nparams)(*param_types) + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_formats" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + return ( + self._pgconn_ptr, + command, + nparams, + atypes, + aparams, + alenghts, + aformats, + result_format, + ) + + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[Sequence[int]] = None, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + if not isinstance(command, bytes): + raise TypeError(f"'command' must be bytes, got {type(command)} instead") + + if not param_types: + nparams = 0 + atypes = None + else: + nparams = len(param_types) + atypes = (impl.Oid * nparams)(*param_types) + + self._ensure_pgconn() + rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def exec_prepared( + self, + name: bytes, + param_values: Optional[Sequence["abc.Buffer"]], + param_formats: Optional[Sequence[int]] = None, + result_format: int = 0, + ) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + + aparams: Optional[Array[c_char_p]] + alenghts: Optional[Array[c_int]] + if param_values: + nparams = len(param_values) + aparams = (c_char_p * nparams)( + *( + # convert bytearray/memoryview to bytes + b if b is None or isinstance(b, bytes) else bytes(b) + for b in param_values + ) + ) + alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values)) + else: + nparams = 0 + aparams = alenghts = None + + if not param_formats: + aformats = None + else: + if len(param_formats) != nparams: + raise ValueError( + "got %d param_values but %d param_types" + % (nparams, len(param_formats)) + ) + aformats = (c_int * nparams)(*param_formats) + + self._ensure_pgconn() + rv = impl.PQexecPrepared( + self._pgconn_ptr, + name, + nparams, + aparams, + alenghts, + aformats, + result_format, + ) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def describe_prepared(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePrepared(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + return PGresult(rv) + + def send_describe_prepared(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePrepared(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe prepared failed: {error_message(self)}" + ) + + def describe_portal(self, name: bytes) -> "PGresult": + if not isinstance(name, bytes): + raise TypeError(f"'name' must be bytes, got {type(name)} instead") + self._ensure_pgconn() + rv = impl.PQdescribePortal(self._pgconn_ptr, name) + if not rv: + raise MemoryError("couldn't allocate PGresult") + 
return PGresult(rv) + + def send_describe_portal(self, name: bytes) -> None: + if not isinstance(name, bytes): + raise TypeError(f"bytes expected, got {type(name)} instead") + self._ensure_pgconn() + if not impl.PQsendDescribePortal(self._pgconn_ptr, name): + raise e.OperationalError( + f"sending describe portal failed: {error_message(self)}" + ) + + def get_result(self) -> Optional["PGresult"]: + rv = impl.PQgetResult(self._pgconn_ptr) + return PGresult(rv) if rv else None + + def consume_input(self) -> None: + if 1 != impl.PQconsumeInput(self._pgconn_ptr): + raise e.OperationalError(f"consuming input failed: {error_message(self)}") + + def is_busy(self) -> int: + return impl.PQisBusy(self._pgconn_ptr) + + @property + def nonblocking(self) -> int: + return impl.PQisnonblocking(self._pgconn_ptr) + + @nonblocking.setter + def nonblocking(self, arg: int) -> None: + if 0 > impl.PQsetnonblocking(self._pgconn_ptr, arg): + raise e.OperationalError( + f"setting nonblocking failed: {error_message(self)}" + ) + + def flush(self) -> int: + # PQflush segfaults if it receives a NULL connection + if not self._pgconn_ptr: + raise e.OperationalError("flushing failed: the connection is closed") + rv: int = impl.PQflush(self._pgconn_ptr) + if rv < 0: + raise e.OperationalError(f"flushing failed: {error_message(self)}") + return rv + + def set_single_row_mode(self) -> None: + if not impl.PQsetSingleRowMode(self._pgconn_ptr): + raise e.OperationalError("setting single row mode failed") + + def get_cancel(self) -> "PGcancel": + """ + Create an object with the information needed to cancel a command. + + See :pq:`PQgetCancel` for details. + """ + rv = impl.PQgetCancel(self._pgconn_ptr) + if not rv: + raise e.OperationalError("couldn't create cancel object") + return PGcancel(rv) + + def notifies(self) -> Optional[PGnotify]: + ptr = impl.PQnotifies(self._pgconn_ptr) + if ptr: + c = ptr.contents + return PGnotify(c.relname, c.be_pid, c.extra) + impl.PQfreemem(ptr) + else: + return None + + def put_copy_data(self, buffer: "abc.Buffer") -> int: + if not isinstance(buffer, bytes): + buffer = bytes(buffer) + rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer)) + if rv < 0: + raise e.OperationalError(f"sending copy data failed: {error_message(self)}") + return rv + + def put_copy_end(self, error: Optional[bytes] = None) -> int: + rv = impl.PQputCopyEnd(self._pgconn_ptr, error) + if rv < 0: + raise e.OperationalError(f"sending copy end failed: {error_message(self)}") + return rv + + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: + buffer_ptr = c_char_p() + nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_) + if nbytes == -2: + raise e.OperationalError( + f"receiving copy data failed: {error_message(self)}" + ) + if buffer_ptr: + # TODO: do it without copy + data = string_at(buffer_ptr, nbytes) + impl.PQfreemem(buffer_ptr) + return nbytes, memoryview(data) + else: + return nbytes, memoryview(b"") + + def trace(self, fileno: int) -> None: + """ + Enable tracing of the client/server communication to a file stream. + + See :pq:`PQtrace` for details. + """ + if sys.platform != "linux": + raise e.NotSupportedError("currently only supported on Linux") + stream = impl.fdopen(fileno, b"w") + impl.PQtrace(self._pgconn_ptr, stream) + + def set_trace_flags(self, flags: Trace) -> None: + """ + Configure tracing behavior of client/server communication. + + :param flags: operating mode of tracing. + + See :pq:`PQsetTraceFlags` for details. 
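+
+ For instance (a minimal sketch, assuming ``pgconn`` is a connected `!PGconn`
+ and using the `!Trace` flags enum from `psycopg.pq`)::
+
+ pgconn.set_trace_flags(Trace.SUPPRESS_TIMESTAMPS | Trace.REGRESS_MODE)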
+ """ + impl.PQsetTraceFlags(self._pgconn_ptr, flags) + + def untrace(self) -> None: + """ + Disable tracing, previously enabled through `trace()`. + + See :pq:`PQuntrace` for details. + """ + impl.PQuntrace(self._pgconn_ptr) + + def encrypt_password( + self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None + ) -> bytes: + """ + Return the encrypted form of a PostgreSQL password. + + See :pq:`PQencryptPasswordConn` for details. + """ + out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm) + if not out: + raise e.OperationalError( + f"password encryption failed: {error_message(self)}" + ) + + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def make_empty_result(self, exec_status: int) -> "PGresult": + rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status) + if not rv: + raise MemoryError("couldn't allocate empty PGresult") + return PGresult(rv) + + @property + def pipeline_status(self) -> int: + if version() < 140000: + return 0 + return impl.PQpipelineStatus(self._pgconn_ptr) + + def enter_pipeline_mode(self) -> None: + """Enter pipeline mode. + + :raises ~e.OperationalError: in case of failure to enter the pipeline + mode. + """ + if impl.PQenterPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError("failed to enter pipeline mode") + + def exit_pipeline_mode(self) -> None: + """Exit pipeline mode. + + :raises ~e.OperationalError: in case of failure to exit the pipeline + mode. + """ + if impl.PQexitPipelineMode(self._pgconn_ptr) != 1: + raise e.OperationalError(error_message(self)) + + def pipeline_sync(self) -> None: + """Mark a synchronization point in a pipeline. + + :raises ~e.OperationalError: if the connection is not in pipeline mode + or if sync failed. + """ + rv = impl.PQpipelineSync(self._pgconn_ptr) + if rv == 0: + raise e.OperationalError("connection not in pipeline mode") + if rv != 1: + raise e.OperationalError("failed to sync pipeline") + + def send_flush_request(self) -> None: + """Sends a request for the server to flush its output buffer. + + :raises ~e.OperationalError: if the flush request failed. + """ + if impl.PQsendFlushRequest(self._pgconn_ptr) == 0: + raise e.OperationalError(f"flush request failed: {error_message(self)}") + + def _call_bytes( + self, func: Callable[[impl.PGconn_struct], Optional[bytes]] + ) -> bytes: + """ + Call one of the pgconn libpq functions returning a bytes pointer. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + rv = func(self._pgconn_ptr) + assert rv is not None + return rv + + def _call_int(self, func: Callable[[impl.PGconn_struct], int]) -> int: + """ + Call one of the pgconn libpq functions returning an int. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return func(self._pgconn_ptr) + + def _call_bool(self, func: Callable[[impl.PGconn_struct], int]) -> bool: + """ + Call one of the pgconn libpq functions returning a logical value. + """ + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + return bool(func(self._pgconn_ptr)) + + def _ensure_pgconn(self) -> None: + if not self._pgconn_ptr: + raise e.OperationalError("the connection is closed") + + +class PGresult: + """ + Python representation of a libpq result. 
+ """ + + __slots__ = ("_pgresult_ptr",) + + def __init__(self, pgresult_ptr: impl.PGresult_struct): + self._pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr + + def __del__(self) -> None: + self.clear() + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + status = ExecStatus(self.status) + return f"<{cls} [{status.name}] at 0x{id(self):x}>" + + def clear(self) -> None: + self._pgresult_ptr, p = None, self._pgresult_ptr + if p: + PQclear(p) + + @property + def pgresult_ptr(self) -> Optional[int]: + """The pointer to the underlying `!PGresult` structure, as integer. + + `!None` if the result was cleared. + + The value can be used to pass the structure to libpq functions which + psycopg doesn't (currently) wrap, either in C or in Python using FFI + libraries such as `ctypes`. + """ + if self._pgresult_ptr is None: + return None + + return addressof(self._pgresult_ptr.contents) # type: ignore[attr-defined] + + @property + def status(self) -> int: + return impl.PQresultStatus(self._pgresult_ptr) + + @property + def error_message(self) -> bytes: + return impl.PQresultErrorMessage(self._pgresult_ptr) + + def error_field(self, fieldcode: int) -> Optional[bytes]: + return impl.PQresultErrorField(self._pgresult_ptr, fieldcode) + + @property + def ntuples(self) -> int: + return impl.PQntuples(self._pgresult_ptr) + + @property + def nfields(self) -> int: + return impl.PQnfields(self._pgresult_ptr) + + def fname(self, column_number: int) -> Optional[bytes]: + return impl.PQfname(self._pgresult_ptr, column_number) + + def ftable(self, column_number: int) -> int: + return impl.PQftable(self._pgresult_ptr, column_number) + + def ftablecol(self, column_number: int) -> int: + return impl.PQftablecol(self._pgresult_ptr, column_number) + + def fformat(self, column_number: int) -> int: + return impl.PQfformat(self._pgresult_ptr, column_number) + + def ftype(self, column_number: int) -> int: + return impl.PQftype(self._pgresult_ptr, column_number) + + def fmod(self, column_number: int) -> int: + return impl.PQfmod(self._pgresult_ptr, column_number) + + def fsize(self, column_number: int) -> int: + return impl.PQfsize(self._pgresult_ptr, column_number) + + @property + def binary_tuples(self) -> int: + return impl.PQbinaryTuples(self._pgresult_ptr) + + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: + length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number) + if length: + v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number) + return string_at(v, length) + else: + if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number): + return None + else: + return b"" + + @property + def nparams(self) -> int: + return impl.PQnparams(self._pgresult_ptr) + + def param_type(self, param_number: int) -> int: + return impl.PQparamtype(self._pgresult_ptr, param_number) + + @property + def command_status(self) -> Optional[bytes]: + return impl.PQcmdStatus(self._pgresult_ptr) + + @property + def command_tuples(self) -> Optional[int]: + rv = impl.PQcmdTuples(self._pgresult_ptr) + return int(rv) if rv else None + + @property + def oid_value(self) -> int: + return impl.PQoidValue(self._pgresult_ptr) + + def set_attributes(self, descriptions: List[PGresAttDesc]) -> None: + structs = [ + impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore + ] + array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore + rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) + 
if rv == 0: + raise e.OperationalError("PQsetResultAttrs failed") + + +class PGcancel: + """ + Token to cancel the current operation on a connection. + + Created by `PGconn.get_cancel()`. + """ + + __slots__ = ("pgcancel_ptr",) + + def __init__(self, pgcancel_ptr: impl.PGcancel_struct): + self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr + + def __del__(self) -> None: + self.free() + + def free(self) -> None: + """ + Free the data structure created by :pq:`PQgetCancel()`. + + Automatically invoked by `!__del__()`. + + See :pq:`PQfreeCancel()` for details. + """ + self.pgcancel_ptr, p = None, self.pgcancel_ptr + if p: + PQfreeCancel(p) + + def cancel(self) -> None: + """Requests that the server abandon processing of the current command. + + See :pq:`PQcancel()` for details. + """ + buf = create_string_buffer(256) + res = impl.PQcancel( + self.pgcancel_ptr, + byref(buf), # type: ignore[arg-type] + len(buf), + ) + if not res: + raise e.OperationalError( + f"cancel failed: {buf.value.decode('utf8', 'ignore')}" + ) + + +class Conninfo: + """ + Utility object to manipulate connection strings. + """ + + @classmethod + def get_defaults(cls) -> List[ConninfoOption]: + opts = impl.PQconndefaults() + if not opts: + raise MemoryError("couldn't allocate connection defaults") + try: + return cls._options_from_array(opts) + finally: + impl.PQconninfoFree(opts) + + @classmethod + def parse(cls, conninfo: bytes) -> List[ConninfoOption]: + if not isinstance(conninfo, bytes): + raise TypeError(f"bytes expected, got {type(conninfo)} instead") + + errmsg = c_char_p() + rv = impl.PQconninfoParse(conninfo, byref(errmsg)) # type: ignore[arg-type] + if not rv: + if not errmsg: + raise MemoryError("couldn't allocate on conninfo parse") + else: + exc = e.OperationalError( + (errmsg.value or b"").decode("utf8", "replace") + ) + impl.PQfreemem(errmsg) + raise exc + + try: + return cls._options_from_array(rv) + finally: + impl.PQconninfoFree(rv) + + @classmethod + def _options_from_array( + cls, opts: Sequence[impl.PQconninfoOption_struct] + ) -> List[ConninfoOption]: + rv = [] + skws = "keyword envvar compiled val label dispchar".split() + for opt in opts: + if not opt.keyword: + break + d = {kw: getattr(opt, kw) for kw in skws} + d["dispsize"] = opt.dispsize + rv.append(ConninfoOption(**d)) + + return rv + + +class Escaping: + """ + Utility object to escape strings for SQL interpolation. 
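+
+ Example (a minimal sketch, assuming ``pgconn`` is a connected `PGconn`)::
+
+ esc = Escaping(pgconn)
+ esc.escape_literal(b"O'Rourke")  # -> b"'O''Rourke'"
+ esc.escape_identifier(b"select")  # -> b'"select"'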
+ """ + + def __init__(self, conn: Optional[PGconn] = None): + self.conn = conn + + def escape_literal(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_literal failed: no connection provided") + + self.conn._ensure_pgconn() + # TODO: might be done without copy (however C does that) + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_literal failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_identifier(self, data: "abc.Buffer") -> bytes: + if not self.conn: + raise e.OperationalError("escape_identifier failed: no connection provided") + + self.conn._ensure_pgconn() + + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data)) + if not out: + raise e.OperationalError( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + def escape_string(self, data: "abc.Buffer") -> bytes: + if not isinstance(data, bytes): + data = bytes(data) + + if self.conn: + self.conn._ensure_pgconn() + error = c_int() + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeStringConn( + self.conn._pgconn_ptr, + byref(out), # type: ignore[arg-type] + data, + len(data), + byref(error), # type: ignore[arg-type] + ) + + if error: + raise e.OperationalError( + f"escape_string failed: {error_message(self.conn)} bytes" + ) + + else: + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeString( + byref(out), # type: ignore[arg-type] + data, + len(data), + ) + + return out.value + + def escape_bytea(self, data: "abc.Buffer") -> bytes: + len_out = c_size_t() + # TODO: might be able to do without a copy but it's a mess. + # the C library does it better anyway, so maybe not worth optimising + # https://mail.python.org/pipermail/python-dev/2012-September/121780.html + if not isinstance(data, bytes): + data = bytes(data) + if self.conn: + self.conn._ensure_pgconn() + out = impl.PQescapeByteaConn( + self.conn._pgconn_ptr, + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + else: + out = impl.PQescapeBytea( + data, + len(data), + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value - 1) # out includes final 0 + impl.PQfreemem(out) + return rv + + def unescape_bytea(self, data: "abc.Buffer") -> bytes: + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. 
+ if self.conn: + self.conn._ensure_pgconn() + + len_out = c_size_t() + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQunescapeBytea( + data, + byref(t_cast(c_ulong, len_out)), # type: ignore[arg-type] + ) + if not out: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value) + impl.PQfreemem(out) + return rv + + +# importing the ssl module sets up Python's libcrypto callbacks +import ssl # noqa + +# disable libcrypto setup in libpq, so it won't stomp on the callbacks +# that have already been set up +impl.PQinitOpenSSL(1, 0) + +__build_version__ = version() diff --git a/psycopg/psycopg/py.typed b/psycopg/psycopg/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/psycopg/psycopg/py.typed diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py new file mode 100644 index 0000000..cb28b57 --- /dev/null +++ b/psycopg/psycopg/rows.py @@ -0,0 +1,256 @@ +""" +psycopg row factories +""" + +# Copyright (C) 2021 The Psycopg Team + +import functools +from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn +from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar +from collections import namedtuple +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from ._compat import Protocol +from ._encodings import _as_python_identifier + +if TYPE_CHECKING: + from .cursor import BaseCursor, Cursor + from .cursor_async import AsyncCursor + from psycopg.pq.abc import PGresult + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE + +T = TypeVar("T", covariant=True) + +# Row factories + +Row = TypeVar("Row", covariant=True) + + +class RowMaker(Protocol[Row]): + """ + Callable protocol taking a sequence of value and returning an object. + + The sequence of value is what is returned from a database query, already + adapted to the right Python types. The return value is the object that your + program would like to receive: by default (`tuple_row()`) it is a simple + tuple, but it may be any type of object. + + Typically, `!RowMaker` functions are returned by `RowFactory`. + """ + + def __call__(self, __values: Sequence[Any]) -> Row: + ... + + +class RowFactory(Protocol[Row]): + """ + Callable protocol taking a `~psycopg.Cursor` and returning a `RowMaker`. + + A `!RowFactory` is typically called when a `!Cursor` receives a result. + This way it can inspect the cursor state (for instance the + `~psycopg.Cursor.description` attribute) and help a `!RowMaker` to create + a complete object. + + For instance the `dict_row()` `!RowFactory` uses the names of the column to + define the dictionary key and returns a `!RowMaker` function which would + use the values to create a dictionary for each record. + """ + + def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: + ... + + +class AsyncRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking an async cursor as argument. + """ + + def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: + ... + + +class BaseRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking either type of cursor as argument. + """ + + def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: + ... + + +TupleRow: TypeAlias = Tuple[Any, ...] +""" +An alias for the type returned by `tuple_row()` (i.e. a tuple of any content). 
+""" + + +DictRow: TypeAlias = Dict[str, Any] +""" +An alias for the type returned by `dict_row()` + +A `!DictRow` is a dictionary with keys as string and any value returned by the +database. +""" + + +def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]": + r"""Row factory to represent rows as simple tuples. + + This is the default factory, used when `~psycopg.Connection.connect()` or + `~psycopg.Connection.cursor()` are called without a `!row_factory` + parameter. + + """ + # Implementation detail: make sure this is the tuple type itself, not an + # equivalent function, because the C code fast-paths on it. + return tuple + + +def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]": + """Row factory to represent rows as dictionaries. + + The dictionary keys are taken from the column names of the returned columns. + """ + names = _get_names(cursor) + if names is None: + return no_result + + def dict_row_(values: Sequence[Any]) -> Dict[str, Any]: + # https://github.com/python/mypy/issues/2608 + return dict(zip(names, values)) # type: ignore[arg-type] + + return dict_row_ + + +def namedtuple_row( + cursor: "BaseCursor[Any, Any]", +) -> "RowMaker[NamedTuple]": + """Row factory to represent rows as `~collections.namedtuple`. + + The field names are taken from the column names of the returned columns, + with some mangling to deal with invalid names. + """ + res = cursor.pgresult + if not res: + return no_result + + nfields = _get_nfields(res) + if nfields is None: + return no_result + + nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields))) + return nt._make + + +@functools.lru_cache(512) +def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]: + snames = tuple(_as_python_identifier(n.decode(enc)) for n in names) + return namedtuple("Row", snames) # type: ignore[return-value] + + +def class_row(cls: Type[T]) -> BaseRowFactory[T]: + r"""Generate a row factory to represent rows as instances of the class `!cls`. + + The class must support every output column name as a keyword parameter. + + :param cls: The class to return for each row. It must support the fields + returned by the query as keyword arguments. + :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]] + """ + + def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def class_row__(values: Sequence[Any]) -> T: + return cls(**dict(zip(names, values))) # type: ignore[arg-type] + + return class_row__ + + return class_row_ + + +def args_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with positional parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as positional arguments. + """ + + def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]": + def args_row__(values: Sequence[Any]) -> T: + return func(*values) + + return args_row__ + + return args_row_ + + +def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling `!func` with keyword parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as keyword arguments. 
+ """ + + def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]": + names = _get_names(cursor) + if names is None: + return no_result + + def kwargs_row__(values: Sequence[Any]) -> T: + return func(**dict(zip(names, values))) # type: ignore[arg-type] + + return kwargs_row__ + + return kwargs_row_ + + +def no_result(values: Sequence[Any]) -> NoReturn: + """A `RowMaker` that always fail. + + It can be used as return value for a `RowFactory` called with no result. + Note that the `!RowFactory` *will* be called with no result, but the + resulting `!RowMaker` never should. + """ + raise e.InterfaceError("the cursor doesn't have a result") + + +def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]: + res = cursor.pgresult + if not res: + return None + + nfields = _get_nfields(res) + if nfields is None: + return None + + enc = cursor._encoding + return [ + res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr] + ] + + +def _get_nfields(res: "PGresult") -> Optional[int]: + """ + Return the number of columns in a result, if it returns tuples else None + + Take into account the special case of results with zero columns. + """ + nfields = res.nfields + + if ( + res.status == TUPLES_OK + or res.status == SINGLE_TUPLE + # "describe" in named cursors + or (res.status == COMMAND_OK and nfields) + ): + return nfields + else: + return None diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py new file mode 100644 index 0000000..b890d77 --- /dev/null +++ b/psycopg/psycopg/server_cursor.py @@ -0,0 +1,479 @@ +""" +psycopg server-side cursor objects. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, AsyncIterator, List, Iterable, Iterator +from typing import Optional, TypeVar, TYPE_CHECKING, overload +from warnings import warn + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, Query, Params, PQGen +from .rows import Row, RowFactory, AsyncRowFactory +from .cursor import BaseCursor, Cursor +from .generators import execute +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from .connection import Connection + from .connection_async import AsyncConnection + +DEFAULT_ITERSIZE = 100 + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + + +class ServerCursorMixin(BaseCursor[ConnectionType, Row]): + """Mixin to add ServerCursor behaviour and implementation a BaseCursor.""" + + __slots__ = "_name _scrollable _withhold _described itersize _format".split() + + def __init__( + self, + name: str, + scrollable: Optional[bool], + withhold: bool, + ): + self._name = name + self._scrollable = scrollable + self._withhold = withhold + self._described = False + self.itersize: int = DEFAULT_ITERSIZE + self._format = TEXT + + def __repr__(self) -> str: + # Insert the name as the second word + parts = super().__repr__().split(None, 1) + parts.insert(1, f"{self._name!r}") + return " ".join(parts) + + @property + def name(self) -> str: + """The name of the cursor.""" + return self._name + + @property + def scrollable(self) -> Optional[bool]: + """ + Whether the cursor is scrollable or not. + + If `!None` leave the choice to the server. Use `!True` if you want to + use `scroll()` on the cursor. 
+ """ + return self._scrollable + + @property + def withhold(self) -> bool: + """ + If the cursor can be used after the creating transaction has committed. + """ + return self._withhold + + @property + def rownumber(self) -> Optional[int]: + """Index of the next row to fetch in the current result. + + `!None` if there is no result to fetch. + """ + res = self.pgresult + # command_status is empty if the result comes from + # describe_portal, which means that we have just executed the DECLARE, + # so we can assume we are at the first row. + tuples = res and (res.status == TUPLES_OK or res.command_status == b"") + return self._pos if tuples else None + + def _declare_gen( + self, + query: Query, + params: Optional[Params] = None, + binary: Optional[bool] = None, + ) -> PQGen[None]: + """Generator implementing `ServerCursor.execute()`.""" + + query = self._make_declare_statement(query) + + # If the cursor is being reused, the previous one must be closed. + if self._described: + yield from self._close_gen() + self._described = False + + yield from self._start_query(query) + pgq = self._convert_query(query, params) + self._execute_send(pgq, force_extended=True) + results = yield from execute(self._conn.pgconn) + if results[-1].status != COMMAND_OK: + self._raise_for_result(results[-1]) + + # Set the format, which will be used by describe and fetch operations + if binary is None: + self._format = self.format + else: + self._format = BINARY if binary else TEXT + + # The above result only returned COMMAND_OK. Get the cursor shape + yield from self._describe_gen() + + def _describe_gen(self) -> PQGen[None]: + self._pgconn.send_describe_portal(self._name.encode(self._encoding)) + results = yield from execute(self._pgconn) + self._check_results(results) + self._results = results + self._select_current_result(0, format=self._format) + self._described = True + + def _close_gen(self) -> PQGen[None]: + ts = self._conn.pgconn.transaction_status + + # if the connection is not in a sane state, don't even try + if ts != IDLE and ts != INTRANS: + return + + # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already. + if not self._withhold and ts == IDLE: + return + + # if we didn't declare the cursor ourselves we still have to close it + # but we must make sure it exists. + if not self._described: + query = sql.SQL( + "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}" + ).format(sql.Literal(self._name)) + res = yield from self._conn._exec_command(query) + # pipeline mode otherwise, unsupported here. + assert res is not None + if res.ntuples == 0: + return + + query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name)) + yield from self._conn._exec_command(query) + + def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]: + if self.closed: + raise e.InterfaceError("the cursor is closed") + # If we are stealing the cursor, make sure we know its shape + if not self._described: + yield from self._start_query() + yield from self._describe_gen() + + query = sql.SQL("FETCH FORWARD {} FROM {}").format( + sql.SQL("ALL") if num is None else sql.Literal(num), + sql.Identifier(self._name), + ) + res = yield from self._conn._exec_command(query, result_format=self._format) + # pipeline mode otherwise, unsupported here. 
+ assert res is not None + + self.pgresult = res + self._tx.set_pgresult(res, set_loaders=False) + return self._tx.load_rows(0, res.ntuples, self._make_row) + + def _scroll_gen(self, value: int, mode: str) -> PQGen[None]: + if mode not in ("relative", "absolute"): + raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'") + query = sql.SQL("MOVE{} {} FROM {}").format( + sql.SQL(" ABSOLUTE" if mode == "absolute" else ""), + sql.Literal(value), + sql.Identifier(self._name), + ) + yield from self._conn._exec_command(query) + + def _make_declare_statement(self, query: Query) -> sql.Composed: + + if isinstance(query, bytes): + query = query.decode(self._encoding) + if not isinstance(query, sql.Composable): + query = sql.SQL(query) + + parts = [ + sql.SQL("DECLARE"), + sql.Identifier(self._name), + ] + if self._scrollable is not None: + parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL")) + parts.append(sql.SQL("CURSOR")) + if self._withhold: + parts.append(sql.SQL("WITH HOLD")) + parts.append(sql.SQL("FOR")) + parts.append(query) + + return sql.SQL(" ").join(parts) + + +class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="ServerCursor[Any]") + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "ServerCursor[Row]", + connection: "Connection[Any]", + name: str, + *, + row_factory: RowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "Connection[Any]", + name: str, + *, + row_factory: Optional[RowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + Cursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." + " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + def close(self) -> None: + """ + Close the current cursor and free associated resources. + """ + with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + self._conn.wait(self._close_gen()) + super().close() + + def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + """ + Open a cursor to execute a query to the database. 
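+
+ A sketch of typical use (assuming an open connection ``conn``; the cursor
+ name ``"mycur"`` is arbitrary)::
+
+ with conn.cursor(name="mycur") as cur:
+ cur.execute("SELECT generate_series(1, 1000000)")
+ for record in cur:
+ print(record)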
+ """ + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + with self._conn.lock: + self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + """Method not implemented for server-side cursors.""" + raise e.NotSupportedError("executemany not supported on server-side cursors") + + def fetchone(self) -> Optional[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + def fetchall(self) -> List[Row]: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + def __iter__(self) -> Iterator[Row]: + while True: + with self._conn.lock: + recs = self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + def scroll(self, value: int, mode: str = "relative") -> None: + with self._conn.lock: + self._conn.wait(self._scroll_gen(value, mode)) + # Postgres doesn't have a reliable way to report a cursor out of bound + if mode == "relative": + self._pos += value + else: + self._pos = value + + +class AsyncServerCursor( + ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row] +): + __module__ = "psycopg" + __slots__ = () + _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]") + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Row]", + name: str, + *, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + @overload + def __init__( + self: "AsyncServerCursor[Row]", + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: AsyncRowFactory[Row], + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + ... + + def __init__( + self, + connection: "AsyncConnection[Any]", + name: str, + *, + row_factory: Optional[AsyncRowFactory[Row]] = None, + scrollable: Optional[bool] = None, + withhold: bool = False, + ): + AsyncCursor.__init__( + self, connection, row_factory=row_factory or connection.row_factory + ) + ServerCursorMixin.__init__(self, name, scrollable, withhold) + + def __del__(self) -> None: + if not self.closed: + warn( + f"the server-side cursor {self} was deleted while still open." 
+ " Please use 'with' or '.close()' to close the cursor properly", + ResourceWarning, + ) + + async def close(self) -> None: + async with self._conn.lock: + if self.closed: + return + if not self._conn.closed: + await self._conn.wait(self._close_gen()) + await super().close() + + async def execute( + self: _Self, + query: Query, + params: Optional[Params] = None, + *, + binary: Optional[bool] = None, + **kwargs: Any, + ) -> _Self: + if kwargs: + raise TypeError(f"keyword not supported: {list(kwargs)[0]}") + if self._pgconn.pipeline_status: + raise e.NotSupportedError( + "server-side cursors not supported in pipeline mode" + ) + + try: + async with self._conn.lock: + await self._conn.wait(self._declare_gen(query, params, binary)) + except e.Error as ex: + raise ex.with_traceback(None) + + return self + + async def executemany( + self, + query: Query, + params_seq: Iterable[Params], + *, + returning: bool = True, + ) -> None: + raise e.NotSupportedError("executemany not supported on server-side cursors") + + async def fetchone(self) -> Optional[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(1)) + if recs: + self._pos += 1 + return recs[0] + else: + return None + + async def fetchmany(self, size: int = 0) -> List[Row]: + if not size: + size = self.arraysize + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(size)) + self._pos += len(recs) + return recs + + async def fetchall(self) -> List[Row]: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(None)) + self._pos += len(recs) + return recs + + async def __aiter__(self) -> AsyncIterator[Row]: + while True: + async with self._conn.lock: + recs = await self._conn.wait(self._fetch_gen(self.itersize)) + for rec in recs: + self._pos += 1 + yield rec + if len(recs) < self.itersize: + break + + async def scroll(self, value: int, mode: str = "relative") -> None: + async with self._conn.lock: + await self._conn.wait(self._scroll_gen(value, mode)) diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py new file mode 100644 index 0000000..099a01c --- /dev/null +++ b/psycopg/psycopg/sql.py @@ -0,0 +1,467 @@ +""" +SQL composition utility module +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +import string +from abc import ABC, abstractmethod +from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union + +from .pq import Escaping +from .abc import AdaptContext +from .adapt import Transformer, PyFormat +from ._compat import LiteralString +from ._encodings import conn_encoding + + +def quote(obj: Any, context: Optional[AdaptContext] = None) -> str: + """ + Adapt a Python object to a quoted SQL string. + + Use this function only if you absolutely want to convert a Python string to + an SQL quoted literal to use e.g. to generate batch SQL and you won't have + a connection available when you will need to use it. + + This function is relatively inefficient, because it doesn't cache the + adaptation rules. If you pass a `!context` you can adapt the adaptation + rules used, otherwise only global rules are used. + + """ + return Literal(obj).as_string(context) + + +class Composable(ABC): + """ + Abstract base class for objects that can be used to compose an SQL string. + + `!Composable` objects can be passed directly to + `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`, + `~psycopg.Cursor.copy()` in place of the query string. 
+ + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. + """ + + def __init__(self, obj: Any): + self._obj = obj + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._obj!r})" + + @abstractmethod + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + """ + Return the value of the object as bytes. + + :param context: the context to evaluate the object into. + :type context: `connection` or `cursor` + + The method is automatically invoked by `~psycopg.Cursor.execute()`, + `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a + `!Composable` is passed instead of the query string. + + """ + raise NotImplementedError + + def as_string(self, context: Optional[AdaptContext]) -> str: + """ + Return the value of the object as string. + + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` + + """ + conn = context.connection if context else None + enc = conn_encoding(conn) + b = self.as_bytes(context) + if isinstance(b, bytes): + return b.decode(enc) + else: + # buffer object + return codecs.lookup(enc).decode(b)[0] + + def __add__(self, other: "Composable") -> "Composed": + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composable): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + def __mul__(self, n: int) -> "Composed": + return Composed([self] * n) + + def __eq__(self, other: Any) -> bool: + return type(self) is type(other) and self._obj == other._obj + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Composed(Composable): + """ + A `Composable` object made of a sequence of `!Composable`. + + The object is usually created using `!Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of objects as arguments: if they are not `!Composable` they will + be wrapped in a `Literal`. + + Example:: + + >>> comp = sql.Composed( + ... [sql.SQL("INSERT INTO "), sql.Identifier("table")]) + >>> print(comp.as_string(conn)) + INSERT INTO "table" + + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). + """ + + _obj: List[Composable] + + def __init__(self, seq: Sequence[Any]): + seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq] + super().__init__(seq) + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + return b"".join(obj.as_bytes(context) for obj in self._obj) + + def __iter__(self) -> Iterator[Composable]: + return iter(self._obj) + + def __add__(self, other: Composable) -> "Composed": + if isinstance(other, Composed): + return Composed(self._obj + other._obj) + if isinstance(other, Composable): + return Composed(self._obj + [other]) + else: + return NotImplemented + + def join(self, joiner: Union["SQL", LiteralString]) -> "Composed": + """ + Return a new `!Composed` interposing the `!joiner` with the `!Composed` items. + + The `!joiner` must be a `SQL` or a string which will be interpreted as + an `SQL`. 
+ + Example:: + + >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed + >>> print(fields.join(', ').as_string(conn)) + "foo", "bar" + + """ + if isinstance(joiner, str): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be strings or SQL," + f" got {joiner!r} instead" + ) + + return joiner.join(self._obj) + + +class SQL(Composable): + """ + A `Composable` representing a snippet of SQL statement. + + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). + + The `!obj` string doesn't undergo any form of escaping, so it is not + suitable to represent variable identifiers or values: you should only use + it to pass constant strings representing templates or snippets of SQL + statements; use other objects such as `Identifier` or `Literal` to + represent variable parts. + + Example:: + + >>> query = sql.SQL("SELECT {0} FROM {1}").format( + ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), + ... sql.Identifier('table')) + >>> print(query.as_string(conn)) + SELECT "foo", "bar" FROM "table" + """ + + _obj: LiteralString + _formatter = string.Formatter() + + def __init__(self, obj: LiteralString): + super().__init__(obj) + if not isinstance(obj, str): + raise TypeError(f"SQL values must be strings, got {obj!r} instead") + + def as_string(self, context: Optional[AdaptContext]) -> str: + return self._obj + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + enc = "utf-8" + if context: + enc = conn_encoding(context.connection) + return self._obj.encode(enc) + + def format(self, *args: Any, **kwargs: Any) -> Composed: + """ + Merge `Composable` objects into a template. + + :param args: parameters to replace to numbered (``{0}``, ``{1}``) or + auto-numbered (``{}``) placeholders + :param kwargs: parameters to replace to named (``{name}``) placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``), numbered (``{0}``, + ``{1}``...), and named placeholders (``{name}``), with positional + arguments replacing the numbered placeholders and keywords replacing + the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``) + are not supported. + + If a `!Composable` objects is passed to the template it will be merged + according to its `as_string()` method. If any other Python object is + passed, it will be wrapped in a `Literal` object and so escaped + according to SQL rules. + + Example:: + + >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + SELECT * FROM "people" WHERE "id" = %s + + >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}") + ... .format(tbl=sql.Identifier('people'), name="O'Rourke")) + ... .as_string(conn)) + SELECT * FROM "people" WHERE name = 'O''Rourke' + + """ + rv: List[Composable] = [] + autonum: Optional[int] = 0 + # TODO: this is probably not the right way to whitelist pre + # pyre complains. Will wait for mypy to complain too to fix. 
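+ # string.Formatter().parse() yields (literal_text, field_name, format_spec,
+ # conversion) tuples; field_name is None for a trailing literal chunk.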
+ pre: LiteralString + for pre, name, spec, conv in self._formatter.parse(self._obj): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual" + ) + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic" + ) + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + + def join(self, seq: Iterable[Composable]) -> Composed: + """ + Join a sequence of `Composable`. + + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's string to separate the elements in `!seq`. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. + + Example:: + + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) + >>> print(snip.as_string(conn)) + "foo", "bar", "baz" + """ + rv = [] + it = iter(seq) + try: + rv.append(next(it)) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composable): + """ + A `Composable` representing an SQL identifier or a dot-separated sequence. + + Identifiers usually represent names of database objects, such as tables or + fields. PostgreSQL identifiers follow `different rules`__ than SQL string + literals for escaping (e.g. they use double quotes instead of single). + + .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \ + SQL-SYNTAX-IDENTIFIERS + + Example:: + + >>> t1 = sql.Identifier("foo") + >>> t2 = sql.Identifier("ba'r") + >>> t3 = sql.Identifier('ba"z') + >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn)) + "foo", "ba'r", "ba""z" + + Multiple strings can be passed to the object to represent a qualified name, + i.e. a dot-separated sequence of identifiers. + + Example:: + + >>> query = sql.SQL("SELECT {} FROM {}").format( + ... sql.Identifier("table", "field"), + ... sql.Identifier("schema", "table")) + >>> print(query.as_string(conn)) + SELECT "table"."field" FROM "schema"."table" + + """ + + _obj: Sequence[str] + + def __init__(self, *strings: str): + # init super() now to make the __repr__ not explode in case of error + super().__init__(strings) + + if not strings: + raise TypeError("Identifier cannot be empty") + + for s in strings: + if not isinstance(s, str): + raise TypeError( + f"SQL identifier parts must be strings, got {s!r} instead" + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + if not conn: + raise ValueError("a connection is necessary for Identifier") + esc = Escaping(conn.pgconn) + enc = conn_encoding(conn) + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + return b".".join(escs) + + +class Literal(Composable): + """ + A `Composable` representing an SQL value to include in a query. + + Usually you will want to include placeholders in the query and pass values + as `~cursor.execute()` arguments. If however you really really need to + include a literal value in the query you can use this object. 
+ + The string returned by `!as_string()` follows the normal :ref:`adaptation + rules <types-adaptation>` for Python objects. + + Example:: + + >>> s1 = sql.Literal("fo'o") + >>> s2 = sql.Literal(42) + >>> s3 = sql.Literal(date(2000, 1, 1)) + >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) + 'fo''o', 42, '2000-01-01'::date + + """ + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + tx = Transformer.from_context(context) + return tx.as_literal(self._obj) + + +class Placeholder(Composable): + """A `Composable` representing a placeholder for query parameters. + + If the name is specified, generate a named placeholder (e.g. ``%(name)s``, + ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``, + ``%b``). + + The object is useful to generate SQL queries with a variable number of + arguments. + + Examples:: + + >>> names = ['foo', 'bar', 'baz'] + + >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) + >>> print(q1.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s) + + >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) + >>> print(q2.as_string(conn)) + INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s) + + """ + + def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO): + super().__init__(name) + if not isinstance(name, str): + raise TypeError(f"expected string as name, got {name!r}") + + if ")" in name: + raise ValueError(f"invalid name: {name!r}") + + if type(format) is str: + format = PyFormat(format) + if not isinstance(format, PyFormat): + raise TypeError( + f"expected PyFormat as format, got {type(format).__name__!r}" + ) + + self._format: PyFormat = format + + def __repr__(self) -> str: + parts = [] + if self._obj: + parts.append(repr(self._obj)) + if self._format is not PyFormat.AUTO: + parts.append(f"format={self._format.name}") + + return f"{self.__class__.__name__}({', '.join(parts)})" + + def as_string(self, context: Optional[AdaptContext]) -> str: + code = self._format.value + return f"%({self._obj}){code}" if self._obj else f"%{code}" + + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + conn = context.connection if context else None + enc = conn_encoding(conn) + return self.as_string(context).encode(enc) + + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py new file mode 100644 index 0000000..e13486e --- /dev/null +++ b/psycopg/psycopg/transaction.py @@ -0,0 +1,290 @@ +""" +Transaction context managers returned by Connection.transaction() +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from types import TracebackType +from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING + +from . import pq +from . import sql +from . import errors as e +from .abc import ConnectionType, PQGen + +if TYPE_CHECKING: + from typing import Any + from .connection import Connection + from .connection_async import AsyncConnection + +IDLE = pq.TransactionStatus.IDLE + +OK = pq.ConnStatus.OK + +logger = logging.getLogger(__name__) + + +class Rollback(Exception): + """ + Exit the current `Transaction` context immediately and rollback any changes + made within this context. 
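+
+ A minimal sketch (assuming an open connection ``conn``; ``mytable`` is a
+ placeholder table name)::
+
+ with conn.transaction():
+ conn.execute("DELETE FROM mytable")
+ raise Rollback()  # the DELETE is rolled back and no exception propagates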
+ + If a transaction context is specified in the constructor, rollback + enclosing transactions contexts up to and including the one specified. + """ + + __module__ = "psycopg" + + def __init__( + self, + transaction: Union["Transaction", "AsyncTransaction", None] = None, + ): + self.transaction = transaction + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}({self.transaction!r})" + + +class OutOfOrderTransactionNesting(e.ProgrammingError): + """Out-of-order transaction nesting detected""" + + +class BaseTransaction(Generic[ConnectionType]): + def __init__( + self, + connection: ConnectionType, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ): + self._conn = connection + self.pgconn = self._conn.pgconn + self._savepoint_name = savepoint_name or "" + self.force_rollback = force_rollback + self._entered = self._exited = False + self._outer_transaction = False + self._stack_index = -1 + + @property + def savepoint_name(self) -> Optional[str]: + """ + The name of the savepoint; `!None` if handling the main transaction. + """ + # Yes, it may change on __enter__. No, I don't care, because the + # un-entered state is outside the public interface. + return self._savepoint_name + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = pq.misc.connection_summary(self.pgconn) + if not self._entered: + status = "inactive" + elif not self._exited: + status = "active" + else: + status = "terminated" + + sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" + return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" + + def _enter_gen(self) -> PQGen[None]: + if self._entered: + raise TypeError("transaction blocks can be used only once") + self._entered = True + + self._push_savepoint() + for command in self._get_enter_commands(): + yield from self._conn._exec_command(command) + + def _exit_gen( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> PQGen[bool]: + if not exc_val and not self.force_rollback: + yield from self._commit_gen() + return False + else: + # try to rollback, but if there are problems (connection in a bad + # state) just warn without clobbering the exception bubbling up. + try: + return (yield from self._rollback_gen(exc_val)) + except OutOfOrderTransactionNesting: + # Clobber an exception happened in the block with the exception + # caused by out-of-order transaction detected, so make the + # behaviour consistent with _commit_gen and to make sure the + # user fixes this condition, which is unrelated from + # operational error that might arise in the block. 
+ raise + except Exception as exc2: + logger.warning("error ignored in rollback of %s: %s", self, exc2) + return False + + def _commit_gen(self) -> PQGen[None]: + ex = self._pop_savepoint("commit") + self._exited = True + if ex: + raise ex + + for command in self._get_commit_commands(): + yield from self._conn._exec_command(command) + + def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: + if isinstance(exc_val, Rollback): + logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True) + + ex = self._pop_savepoint("rollback") + self._exited = True + if ex: + raise ex + + for command in self._get_rollback_commands(): + yield from self._conn._exec_command(command) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False + + def _get_enter_commands(self) -> Iterator[bytes]: + if self._outer_transaction: + yield self._conn._get_tx_start_command() + + if self._savepoint_name: + yield ( + sql.SQL("SAVEPOINT {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + def _get_commit_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("RELEASE {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"COMMIT" + + def _get_rollback_commands(self) -> Iterator[bytes]: + if self._savepoint_name and not self._outer_transaction: + yield ( + sql.SQL("ROLLBACK TO {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + yield ( + sql.SQL("RELEASE {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._num_transactions + yield b"ROLLBACK" + + # Also clear the prepared statements cache. + if self._conn._prepared.clear(): + yield from self._conn._prepared.get_maintenance_commands() + + def _push_savepoint(self) -> None: + """ + Push the transaction on the connection transactions stack. + + Also set the internal state of the object and verify consistency. + """ + self._outer_transaction = self.pgconn.transaction_status == IDLE + if self._outer_transaction: + # outer transaction: if no name it's only a begin, else + # there will be an additional savepoint + assert not self._conn._num_transactions + else: + # inner transaction: it always has a name + if not self._savepoint_name: + self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}" + + self._stack_index = self._conn._num_transactions + self._conn._num_transactions += 1 + + def _pop_savepoint(self, action: str) -> Optional[Exception]: + """ + Pop the transaction from the connection transactions stack. + + Also verify the state consistency. + """ + self._conn._num_transactions -= 1 + if self._conn._num_transactions == self._stack_index: + return None + + return OutOfOrderTransactionNesting( + f"transaction {action} at the wrong nesting level: {self}" + ) + + +class Transaction(BaseTransaction["Connection[Any]"]): + """ + Returned by `Connection.transaction()` to handle a transaction block. 
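+
+    A minimal usage sketch (assuming an existing connection `!conn`; the
+    `!data` table is hypothetical)::
+
+        with conn.transaction():
+            conn.execute("INSERT INTO data VALUES (%s)", (10,))
+            with conn.transaction():
+                # the inner block is implemented with a SAVEPOINT
+                conn.execute("INSERT INTO data VALUES (%s)", (20,))
+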
+ """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="Transaction") + + @property + def connection(self) -> "Connection[Any]": + """The connection the object is managing.""" + return self._conn + + def __enter__(self: _Self) -> _Self: + with self._conn.lock: + self._conn.wait(self._enter_gen()) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + with self._conn.lock: + return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False + + +class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): + """ + Returned by `AsyncConnection.transaction()` to handle a transaction block. + """ + + __module__ = "psycopg" + + _Self = TypeVar("_Self", bound="AsyncTransaction") + + @property + def connection(self) -> "AsyncConnection[Any]": + return self._conn + + async def __aenter__(self: _Self) -> _Self: + async with self._conn.lock: + await self._conn.wait(self._enter_gen()) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self.pgconn.status == OK: + async with self._conn.lock: + return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) + else: + return False diff --git a/psycopg/psycopg/types/__init__.py b/psycopg/psycopg/types/__init__.py new file mode 100644 index 0000000..bdddf05 --- /dev/null +++ b/psycopg/psycopg/types/__init__.py @@ -0,0 +1,11 @@ +""" +psycopg types package +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. import _typeinfo + +# Exposed here +TypeInfo = _typeinfo.TypeInfo +TypesRegistry = _typeinfo.TypesRegistry diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py new file mode 100644 index 0000000..e35c5e7 --- /dev/null +++ b/psycopg/psycopg/types/array.py @@ -0,0 +1,464 @@ +""" +Adapters for arrays +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type + +from .. import pq +from .. import errors as e +from .. 
import postgres +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._compat import cache, prod +from .._struct import pack_len, unpack_len +from .._cmodule import _psycopg +from ..postgres import TEXT_OID, INVALID_OID +from .._typeinfo import TypeInfo + +_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid +_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack) +_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from) +_struct_dim = struct.Struct("!II") # dim, lower bound +_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack) +_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from) + +TEXT_ARRAY_OID = postgres.types["text"].array_oid + +PY_TEXT = PyFormat.TEXT +PQ_BINARY = pq.Format.BINARY + + +class BaseListDumper(RecursiveDumper): + element_oid = 0 + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + if cls is NoneType: + cls = list + + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + if self.element_oid and context: + sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format) + self.sub_dumper = sdclass(NoneType, context) + + def _find_list_element(self, L: List[Any], format: PyFormat) -> Any: + """ + Find the first non-null element of an eventually nested list + """ + items = list(self._flatiter(L, set())) + types = {type(item): item for item in items} + if not types: + return None + + if len(types) == 1: + t, v = types.popitem() + else: + # More than one type in the list. It might be still good, as long + # as they dump with the same oid (e.g. IPv4Network, IPv6Network). + dumpers = [self._tx.get_dumper(item, format) for item in types.values()] + oids = set(d.oid for d in dumpers) + if len(oids) == 1: + t, v = types.popitem() + else: + raise e.DataError( + "cannot dump lists of mixed types;" + f" got: {', '.join(sorted(t.__name__ for t in types))}" + ) + + # Checking for precise type. If the type is a subclass (e.g. Int4) + # we assume the user knows what type they are passing. + if t is not int: + return v + + # If we got an int, let's see what is the biggest one in order to + # choose the smallest OID and allow Postgres to do the right cast. + imax: int = max(items) + imin: int = min(items) + if imin >= 0: + return imax + else: + return max(imax, -imin - 1) + + def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: + if id(L) in seen: + raise e.DataError("cannot dump a recursive list") + + seen.add(id(L)) + + for item in L: + if type(item) is list: + yield from self._flatiter(item, seen) + elif item is not None: + yield item + + return None + + def _get_base_type_info(self, base_oid: int) -> TypeInfo: + """ + Return info about the base type. + + Return text info as fallback. 
+ """ + if base_oid: + info = self._tx.adapters.types.get(base_oid) + if info: + return info + + return self._tx.adapters.types["text"] + + +class ListDumper(BaseListDumper): + + delimiter = b"," + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return self.cls + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + # Empty lists can only be dumped as text if the type is unknown. + return self + + sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + + # We consider an array of unknowns as unknown, so we can dump empty + # lists or lists containing only None elements. + if sd.oid != INVALID_OID: + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + dumper.delimiter = info.delimiter.encode() + else: + dumper.oid = INVALID_OID + + return dumper + + # Double quotes and backslashes embedded in element values will be + # backslash-escaped. + _re_esc = re.compile(rb'(["\\])') + + def dump(self, obj: List[Any]) -> bytes: + tokens: List[Buffer] = [] + needs_quotes = _get_needs_quotes_regexp(self.delimiter).search + + def dump_list(obj: List[Any]) -> None: + if not obj: + tokens.append(b"{}") + return + + tokens.append(b"{") + for item in obj: + if isinstance(item, list): + dump_list(item) + elif item is not None: + ad = self._dump_item(item) + if needs_quotes(ad): + if not isinstance(ad, bytes): + ad = bytes(ad) + ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"' + tokens.append(ad) + else: + tokens.append(b"NULL") + + tokens.append(self.delimiter) + + tokens[-1] = b"}" + + dump_list(obj) + + return b"".join(tokens) + + def _dump_item(self, item: Any) -> Buffer: + if self.sub_dumper: + return self.sub_dumper.dump(item) + else: + return self._tx.get_dumper(item, PY_TEXT).dump(item) + + +@cache +def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]: + """Return a regexp to recognise when a value needs quotes + + from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO + + The array output routine will put double quotes around element values if + they are empty strings, contain curly braces, delimiter characters, + double quotes, backslashes, or white space, or match the word NULL. 
+ """ + return re.compile( + rb"""(?xi) + ^$ # the empty string + | ["{}%s\\\s] # or a char to escape + | ^null$ # or the word NULL + """ + % delimiter + ) + + +class ListBinaryDumper(BaseListDumper): + + format = pq.Format.BINARY + + def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey: + if self.oid: + return self.cls + + item = self._find_list_element(obj, format) + if item is None: + return (self.cls,) + + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.get_key(item, format)) + + def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper": + # If we have an oid we don't need to upgrade + if self.oid: + return self + + item = self._find_list_element(obj, format) + if item is None: + return ListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format.from_pq(self.format)) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + info = self._get_base_type_info(sd.oid) + dumper.oid = info.array_oid or TEXT_ARRAY_OID + + return dumper + + def dump(self, obj: List[Any]) -> bytes: + # Postgres won't take unknown for element oid: fall back on text + sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID + + if not obj: + return _pack_head(0, 0, sub_oid) + + data: List[Buffer] = [b"", b""] # placeholders to avoid a resize + dims: List[int] = [] + hasnull = 0 + + def calc_dims(L: List[Any]) -> None: + if isinstance(L, self.cls): + if not L: + raise e.DataError("lists cannot contain empty lists") + dims.append(len(L)) + calc_dims(L[0]) + + calc_dims(obj) + + def dump_list(L: List[Any], dim: int) -> None: + nonlocal hasnull + if len(L) != dims[dim]: + raise e.DataError("nested lists have inconsistent lengths") + + if dim == len(dims) - 1: + for item in L: + if item is not None: + # If we get here, the sub_dumper must have been set + ad = self.sub_dumper.dump(item) # type: ignore[union-attr] + data.append(pack_len(len(ad))) + data.append(ad) + else: + hasnull = 1 + data.append(b"\xff\xff\xff\xff") + else: + for item in L: + if not isinstance(item, self.cls): + raise e.DataError("nested lists have inconsistent depths") + dump_list(item, dim + 1) # type: ignore + + dump_list(obj, 0) + + data[0] = _pack_head(len(dims), hasnull, sub_oid) + data[1] = b"".join(_pack_dim(dim, 1) for dim in dims) + return b"".join(data) + + +class ArrayLoader(RecursiveLoader): + + delimiter = b"," + base_oid: int + + def load(self, data: Buffer) -> List[Any]: + loader = self._tx.get_loader(self.base_oid, self.format) + return _load_text(data, loader, self.delimiter) + + +class ArrayBinaryLoader(RecursiveLoader): + + format = pq.Format.BINARY + + def load(self, data: Buffer) -> List[Any]: + return _load_binary(data, self._tx) + + +def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + if not info.array_oid: + raise ValueError(f"the type info {info} doesn't describe an array") + + base: Type[Any] + adapters = context.adapters if context else postgres.adapters + + base = getattr(_psycopg, "ArrayLoader", ArrayLoader) + name = f"{info.name.title()}{base.__name__}" + attribs = { + "base_oid": info.oid, + "delimiter": info.delimiter.encode(), + } + loader = type(name, (base,), attribs) + adapters.register_loader(info.array_oid, loader) + + loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader) + adapters.register_loader(info.array_oid, loader) + + base = ListDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + "delimiter": info.delimiter.encode(), + } 
+ dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + base = ListBinaryDumper + name = f"{info.name.title()}{base.__name__}" + attribs = { + "oid": info.array_oid, + "element_oid": info.oid, + } + dumper = type(name, (base,), attribs) + adapters.register_dumper(None, dumper) + + +def register_default_adapters(context: AdaptContext) -> None: + # The text dumper is more flexible as it can handle lists of mixed type, + # so register it later. + context.adapters.register_dumper(list, ListBinaryDumper) + context.adapters.register_dumper(list, ListDumper) + + +def register_all_arrays(context: AdaptContext) -> None: + """ + Associate the array oid of all the types in Loader.globals. + + This function is designed to be called once at import time, after having + registered all the base loaders. + """ + for t in context.adapters.types: + if t.array_oid: + t.register(context) + + +def _load_text( + data: Buffer, + loader: Loader, + delimiter: bytes = b",", + __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"), +) -> List[Any]: + rv = None + stack: List[Any] = [] + a: List[Any] = [] + rv = a + load = loader.load + + # Remove the dimensions information prefix (``[...]=``) + if data and data[0] == b"["[0]: + if isinstance(data, memoryview): + data = bytes(data) + idx = data.find(b"=") + if idx == -1: + raise e.DataError("malformed array: no '=' after dimension information") + data = data[idx + 1 :] + + re_parse = _get_array_parse_regexp(delimiter) + for m in re_parse.finditer(data): + t = m.group(1) + if t == b"{": + if stack: + stack[-1].append(a) + stack.append(a) + a = [] + + elif t == b"}": + if not stack: + raise e.DataError("malformed array: unexpected '}'") + rv = stack.pop() + + else: + if not stack: + wat = t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else "" + raise e.DataError(f"malformed array: unexpected '{wat}'") + if t == b"NULL": + v = None + else: + if t.startswith(b'"'): + t = __re_unescape.sub(rb"\1", t[1:-1]) + v = load(t) + + stack[-1].append(v) + + assert rv is not None + return rv + + +@cache +def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]: + """ + Return a regexp to tokenize an array representation into item and brackets + """ + return re.compile( + rb"""(?xi) + ( [{}] # open or closed bracket + | " (?: [^"\\] | \\. )* " # or a quoted string + | [^"{}%s\\]+ # or an unquoted non-empty string + ) ,? + """ + % delimiter + ) + + +def _load_binary(data: Buffer, tx: Transformer) -> List[Any]: + ndims, hasnull, oid = _unpack_head(data) + load = tx.get_loader(oid, PQ_BINARY).load + + if not ndims: + return [] + + p = 12 + 8 * ndims + dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)] + nelems = prod(dims) + + out: List[Any] = [None] * nelems + for i in range(nelems): + size = unpack_len(data, p)[0] + p += 4 + if size == -1: + continue + out[i] = load(data[p : p + size]) + p += size + + # fon ndims > 1 we have to aggregate the array into sub-arrays + for dim in dims[-1:0:-1]: + out = [out[i : i + dim] for i in range(0, len(out), dim)] + + return out diff --git a/psycopg/psycopg/types/bool.py b/psycopg/psycopg/types/bool.py new file mode 100644 index 0000000..db7e181 --- /dev/null +++ b/psycopg/psycopg/types/bool.py @@ -0,0 +1,51 @@ +""" +Adapters for booleans. +""" + +# Copyright (C) 2020 The Psycopg Team + +from .. 
import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + + +class BoolDumper(Dumper): + + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"t" if obj else b"f" + + def quote(self, obj: bool) -> bytes: + return b"true" if obj else b"false" + + +class BoolBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bool"].oid + + def dump(self, obj: bool) -> bytes: + return b"\x01" if obj else b"\x00" + + +class BoolLoader(Loader): + def load(self, data: Buffer) -> bool: + return data == b"t" + + +class BoolBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> bool: + return data != b"\x00" + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(bool, BoolDumper) + adapters.register_dumper(bool, BoolBinaryDumper) + adapters.register_loader("bool", BoolLoader) + adapters.register_loader("bool", BoolBinaryLoader) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py new file mode 100644 index 0000000..1c609c3 --- /dev/null +++ b/psycopg/psycopg/types/composite.py @@ -0,0 +1,290 @@ +""" +Support for composite types adaptation. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from collections import namedtuple +from typing import Any, Callable, cast, Iterator, List, Optional +from typing import Sequence, Tuple, Type + +from .. import pq +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader +from .._struct import pack_len, unpack_len +from ..postgres import TEXT_OID +from .._typeinfo import CompositeInfo as CompositeInfo # exported here +from .._encodings import _as_python_identifier + +_struct_oidlen = struct.Struct("!Ii") +_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) +_unpack_oidlen = cast( + Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from +) + + +class SequenceDumper(RecursiveDumper): + def _dump_sequence( + self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes + ) -> bytes: + if not obj: + return start + end + + parts: List[Buffer] = [start] + + for item in obj: + if item is None: + parts.append(sep) + continue + + dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) + ad = dumper.dump(item) + if not ad: + ad = b'""' + elif self._re_needs_quotes.search(ad): + ad = b'"' + self._re_esc.sub(rb"\1\1", ad) + b'"' + + parts.append(ad) + parts.append(sep) + + parts[-1] = end + + return b"".join(parts) + + _re_needs_quotes = re.compile(rb'[",\\\s()]') + _re_esc = re.compile(rb"([\\\"])") + + +class TupleDumper(SequenceDumper): + + # Should be this, but it doesn't work + # oid = postgres_types["record"].oid + + def dump(self, obj: Tuple[Any, ...]) -> bytes: + return self._dump_sequence(obj, b"(", b")", b",") + + +class TupleBinaryDumper(RecursiveDumper): + + format = pq.Format.BINARY + + # Subclasses must set an info + info: CompositeInfo + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + nfields = len(self.info.field_types) + self._tx.set_dumper_types(self.info.field_types, self.format) + self._formats = (PyFormat.from_pq(self.format),) * nfields + + def dump(self, obj: Tuple[Any, ...]) -> bytearray: + out = bytearray(pack_len(len(obj))) + adapted = self._tx.dump_sequence(obj, self._formats) + for i in range(len(obj)): + b = adapted[i] + 
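+            # Each field is encoded as (oid: uint32, length: int32, value
+            # bytes); a None attribute is encoded as (oid, -1) with no value.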
oid = self.info.field_types[i] + if b is not None: + out += _pack_oidlen(oid, len(b)) + out += b + else: + out += _pack_oidlen(oid, -1) + + return out + + +class BaseCompositeLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]: + """ + Split a non-empty representation of a composite type into components. + + Terminators shouldn't be used in `!data` (so that both record and range + representations can be parsed). + """ + for m in self._re_tokenize.finditer(data): + if m.group(1): + yield None + elif m.group(2) is not None: + yield self._re_undouble.sub(rb"\1", m.group(2)) + else: + yield m.group(3) + + # If the final group ended in `,` there is a final NULL in the record + # that the regexp couldn't parse. + if m and m.group().endswith(b","): + yield None + + _re_tokenize = re.compile( + rb"""(?x) + (,) # an empty token, representing NULL + | " ((?: [^"] | "")*) " ,? # or a quoted string + | ([^",)]+) ,? # or an unquoted string + """ + ) + + _re_undouble = re.compile(rb'(["\\])\1') + + +class RecordLoader(BaseCompositeLoader): + def load(self, data: Buffer) -> Tuple[Any, ...]: + if data == b"()": + return () + + cast = self._tx.get_loader(TEXT_OID, self.format).load + return tuple( + cast(token) if token is not None else None + for token in self._parse_record(data[1:-1]) + ) + + +class RecordBinaryLoader(Loader): + format = pq.Format.BINARY + _types_set = False + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._tx = Transformer(context) + + def load(self, data: Buffer) -> Tuple[Any, ...]: + if not self._types_set: + self._config_types(data) + self._types_set = True + + return self._tx.load_sequence( + tuple( + data[offset : offset + length] if length != -1 else None + for _, offset, length in self._walk_record(data) + ) + ) + + def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]: + """ + Yield a sequence of (oid, offset, length) for the content of the record + """ + nfields = unpack_len(data, 0)[0] + i = 4 + for _ in range(nfields): + oid, length = _unpack_oidlen(data, i) + yield oid, i + 8, length + i += (8 + length) if length > 0 else 8 + + def _config_types(self, data: Buffer) -> None: + oids = [r[0] for r in self._walk_record(data)] + self._tx.set_loader_types(oids, self.format) + + +class CompositeLoader(RecordLoader): + + factory: Callable[..., Any] + fields_types: List[int] + _types_set = False + + def load(self, data: Buffer) -> Any: + if not self._types_set: + self._config_types(data) + self._types_set = True + + if data == b"()": + return type(self).factory() + + return type(self).factory( + *self._tx.load_sequence(tuple(self._parse_record(data[1:-1]))) + ) + + def _config_types(self, data: Buffer) -> None: + self._tx.set_loader_types(self.fields_types, self.format) + + +class CompositeBinaryLoader(RecordBinaryLoader): + + format = pq.Format.BINARY + factory: Callable[..., Any] + + def load(self, data: Buffer) -> Any: + r = super().load(data) + return type(self).factory(*r) + + +def register_composite( + info: CompositeInfo, + context: Optional[AdaptContext] = None, + factory: Optional[Callable[..., Any]] = None, +) -> None: + """Register the adapters to load and dump a composite type. + + :param info: The object with the information about the composite to register. + :param context: The context where to register the adapters. 
If `!None`, + register it globally. + :param factory: Callable to convert the sequence of attributes read from + the composite into a Python object. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested composite available?") + + # Register arrays and type info + info.register(context) + + if not factory: + factory = namedtuple( # type: ignore + _as_python_identifier(info.name), + [_as_python_identifier(n) for n in info.field_names], + ) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[BaseCompositeLoader] = type( + f"{info.name.title()}Loader", + (CompositeLoader,), + { + "factory": factory, + "fields_types": info.field_types, + }, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + loader = type( + f"{info.name.title()}BinaryLoader", + (CompositeBinaryLoader,), + {"factory": factory}, + ) + adapters.register_loader(info.oid, loader) + + # If the factory is a type, create and register dumpers for it + if isinstance(factory, type): + dumper = type( + f"{info.name.title()}BinaryDumper", + (TupleBinaryDumper,), + {"oid": info.oid, "info": info}, + ) + adapters.register_dumper(factory, dumper) + + # Default to the text dumper because it is more flexible + dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid}) + adapters.register_dumper(factory, dumper) + + info.python_type = factory + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(tuple, TupleDumper) + adapters.register_loader("record", RecordLoader) + adapters.register_loader("record", RecordBinaryLoader) diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py new file mode 100644 index 0000000..f0dfe83 --- /dev/null +++ b/psycopg/psycopg/types/datetime.py @@ -0,0 +1,754 @@ +""" +Adapters for date/time types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +import struct +from datetime import date, datetime, time, timedelta, timezone +from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING + +from .. 
import postgres +from ..pq import Format +from .._tz import get_tzinfo +from ..abc import AdaptContext, DumperKey +from ..adapt import Buffer, Dumper, Loader, PyFormat +from ..errors import InterfaceError, DataError +from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8 + +if TYPE_CHECKING: + from ..connection import BaseConnection + +_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset +_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack) +_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack) + +_struct_interval = struct.Struct("!qii") # microseconds, days, months +_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack) +_unpack_interval = cast( + Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack +) + +utc = timezone.utc +_pg_date_epoch_days = date(2000, 1, 1).toordinal() +_pg_datetime_epoch = datetime(2000, 1, 1) +_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc) +_py_date_min_days = date.min.toordinal() + + +class DateDumper(Dumper): + + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DateBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["date"].oid + + def dump(self, obj: date) -> bytes: + days = obj.toordinal() - _pg_date_epoch_days + return pack_int4(days) + + +class _BaseTimeDumper(Dumper): + def get_key(self, obj: time, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade to a dumper for timetz (the + # Frankenstein of the data types). + if not obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseTimeTextDumper(_BaseTimeDumper): + def dump(self, obj: time) -> bytes: + return str(obj).encode() + + +class TimeDumper(_BaseTimeTextDumper): + + oid = postgres.types["time"].oid + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzDumper(self.cls) + + +class TimeTzDumper(_BaseTimeTextDumper): + + oid = postgres.types["timetz"].oid + + +class TimeBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["time"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + return pack_int8(us) + + def upgrade(self, obj: time, format: PyFormat) -> Dumper: + if not obj.tzinfo: + return self + else: + return TimeTzBinaryDumper(self.cls) + + +class TimeTzBinaryDumper(_BaseTimeDumper): + + format = Format.BINARY + oid = postgres.types["timetz"].oid + + def dump(self, obj: time) -> bytes: + us = obj.microsecond + 1_000_000 * ( + obj.second + 60 * (obj.minute + 60 * obj.hour) + ) + off = obj.utcoffset() + assert off is not None + return _pack_timetz(us, -int(off.total_seconds())) + + +class _BaseDatetimeDumper(Dumper): + def get_key(self, obj: datetime, format: PyFormat) -> DumperKey: + # Use (cls,) to report the need to upgrade (downgrade, actually) to a + # dumper for naive timestamp. 
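+        # (e.g. a naive datetime(2000, 1, 1) produces the key (datetime,),
+        # so the "timestamp" dumper is used instead of "timestamptz".)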
+ if obj.tzinfo: + return self.cls + else: + return (self.cls,) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + raise NotImplementedError + + +class _BaseDatetimeTextDumper(_BaseDatetimeDumper): + def dump(self, obj: datetime) -> bytes: + # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD) + # the YYYY-MM-DD is always understood correctly. + return str(obj).encode() + + +class DatetimeDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamptz"].oid + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzDumper(self.cls) + + +class DatetimeNoTzDumper(_BaseDatetimeTextDumper): + + oid = postgres.types["timestamp"].oid + + +class DatetimeBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamptz"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetimetz_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + def upgrade(self, obj: datetime, format: PyFormat) -> Dumper: + if obj.tzinfo: + return self + else: + return DatetimeNoTzBinaryDumper(self.cls) + + +class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper): + + format = Format.BINARY + oid = postgres.types["timestamp"].oid + + def dump(self, obj: datetime) -> bytes: + delta = obj - _pg_datetime_epoch + micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds) + return pack_int8(micros) + + +class TimedeltaDumper(Dumper): + + oid = postgres.types["interval"].oid + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + if self.connection: + if ( + self.connection.pgconn.parameter_status(b"IntervalStyle") + == b"sql_standard" + ): + setattr(self, "dump", self._dump_sql) + + def dump(self, obj: timedelta) -> bytes: + # The comma is parsed ok by PostgreSQL but it's not documented + # and it seems brittle to rely on it. CRDB doesn't consume it well. 
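+        # e.g. str(timedelta(days=400, seconds=1)) is "400 days, 0:00:01",
+        # which is dumped here as b"400 days 0:00:01".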
+ return str(obj).encode().replace(b",", b"") + + def _dump_sql(self, obj: timedelta) -> bytes: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + return b"%+d day %+d second %+d microsecond" % ( + obj.days, + obj.seconds, + obj.microseconds, + ) + + +class TimedeltaBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["interval"].oid + + def dump(self, obj: timedelta) -> bytes: + micros = 1_000_000 * obj.seconds + obj.microseconds + return _pack_interval(micros, obj.days, 0) + + +class DateLoader(Loader): + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> date: + if self._order == self._ORDER_YMD: + ye = data[:4] + mo = data[5:7] + da = data[8:] + elif self._order == self._ORDER_DMY: + da = data[:2] + mo = data[3:5] + ye = data[6:] + else: + mo = data[:2] + da = data[3:5] + ye = data[6:] + + try: + return date(int(ye), int(mo), int(da)) + except ValueError as ex: + s = bytes(data).decode("utf8", "replace") + if s == "infinity" or (s and len(s.split()[0]) > 10): + raise DataError(f"date too large (after year 10K): {s!r}") from None + elif s == "-infinity" or "BC" in s: + raise DataError(f"date too small (before year 1): {s!r}") from None + else: + raise DataError(f"can't parse date {s!r}: {ex}") from None + + +class DateBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> date: + days = unpack_int4(data)[0] + _pg_date_epoch_days + try: + return date.fromordinal(days) + except (ValueError, OverflowError): + if days < _py_date_min_days: + raise DataError("date too small (before year 1)") from None + else: + raise DataError("date too large (after year 10K)") from None + + +class TimeLoader(Loader): + + _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?") + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}") + + ho, mi, se, fr = m.groups() + + # Pad the fraction of second to get micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return time(int(ho), int(mi), int(se), us) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse time {s!r}: {e}") from None + + +class TimeBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val = unpack_int8(data)[0] + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + try: + return time(h, m, s, us) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimetzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) : (\d+) : (\d+) (?: \. (\d+) )? # Time and micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? 
# Timezone + $ + """ + ) + + def load(self, data: Buffer) -> time: + m = self._re_format.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}") + + ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone + off = 60 * 60 * int(oh) + if om: + off += 60 * int(om) + if os: + off += int(os) + tz = timezone(timedelta(0, off if sgn == b"+" else -off)) + + try: + return time(int(ho), int(mi), int(se), us, tz) + except ValueError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse timetz {s!r}: {e}") from None + + +class TimetzBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> time: + val, off = _unpack_timetz(data) + + val, us = divmod(val, 1_000_000) + val, s = divmod(val, 60) + h, m = divmod(val, 60) + + try: + return time(h, m, s, us, timezone(timedelta(seconds=-off))) + except ValueError: + raise DataError(f"time not supported by Python: hour={h}") from None + + +class TimestampLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + $ + """ + ) + _re_format_pg = re.compile( + rb"""(?ix) + ^ + [a-z]+ [^a-z0-9] # DoW, separator + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+|[a-z]+) [^a-z0-9] # Month or day + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + [^a-z0-9] (\d+) # Year + $ + """ + ) + + _ORDER_YMD = 0 + _ORDER_DMY = 1 + _ORDER_MDY = 2 + _ORDER_PGDM = 3 + _ORDER_PGMD = 4 + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + + ds = _get_datestyle(self.connection) + if ds.startswith(b"I"): # ISO + self._order = self._ORDER_YMD + elif ds.startswith(b"G"): # German + self._order = self._ORDER_DMY + elif ds.startswith(b"S"): # SQL + self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY + elif ds.startswith(b"P"): # Postgres + self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD + self._re_format = self._re_format_pg + else: + raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + if self._order == self._ORDER_YMD: + ye, mo, da, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_DMY: + da, mo, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + elif self._order == self._ORDER_MDY: + mo, da, ye, ho, mi, se, fr = m.groups() + imo = int(mo) + else: + if self._order == self._ORDER_PGDM: + da, mo, ho, mi, se, fr, ye = m.groups() + else: + mo, da, ho, mi, se, fr, ye = m.groups() + try: + imo = _month_abbr[mo] + except KeyError: + s = mo.decode("utf8", "replace") + raise DataError(f"can't parse month: {s!r}") from None + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + try: + return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us) + except ValueError as ex: + raise _get_timestamp_load_error(self.connection, data, ex) from None + + +class TimestampBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> datetime: + 
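+        # Binary timestamps are microseconds from the PostgreSQL epoch
+        # (2000-01-01); e.g. a value of 86_400_000_000 loads as
+        # datetime(2000, 1, 2, 0, 0).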
micros = unpack_int8(data)[0] + try: + return _pg_datetime_epoch + timedelta(microseconds=micros) + except OverflowError: + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class TimestamptzLoader(Loader): + + _re_format = re.compile( + rb"""(?ix) + ^ + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Date + (?: T | [^a-z0-9] ) # Separator, including T + (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+) # Time + (?: \.(\d+) )? # Micros + ([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? # Timezone + $ + """ + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + ds = _get_datestyle(self.connection) + if not ds.startswith(b"I"): # not ISO + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> datetime: + m = self._re_format.match(data) + if not m: + raise _get_timestamp_load_error(self.connection, data) from None + + ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups() + + # Pad the fraction of second to get the micros + if fr: + us = int(fr) + if len(fr) < 6: + us *= _uspad[len(fr)] + else: + us = 0 + + # Calculate timezone offset + soff = 60 * 60 * int(oh) + if om: + soff += 60 * int(om) + if os: + soff += int(os) + tzoff = timedelta(0, soff if sgn == b"+" else -soff) + + # The return value is a datetime with the timezone of the connection + # (in order to be consistent with the binary loader, which is the only + # thing it can return). So create a temporary datetime object, in utc, + # shift it by the offset parsed from the timestamp, and then move it to + # the connection timezone. + dt = None + ex: Exception + try: + dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc) + return (dt - tzoff).astimezone(self._timezone) + except OverflowError as e: + # If we have created the temporary 'dt' it means that we have a + # datetime close to max, the shift pushed it past max, overflowing. + # In this case return the datetime in a fixed offset timezone. + if dt is not None: + return dt.replace(tzinfo=timezone(tzoff)) + else: + ex = e + except ValueError as e: + ex = e + + raise _get_timestamp_load_error(self.connection, data, ex) from None + + def _load_notimpl(self, data: Buffer) -> datetime: + s = bytes(data).decode("utf8", "replace") + ds = _get_datestyle(self.connection).decode("ascii") + raise NotImplementedError( + f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" + ) + + +class TimestamptzBinaryLoader(Loader): + + format = Format.BINARY + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) + + def load(self, data: Buffer) -> datetime: + micros = unpack_int8(data)[0] + try: + ts = _pg_datetimetz_epoch + timedelta(microseconds=micros) + return ts.astimezone(self._timezone) + except OverflowError: + # If we were asked about a timestamp which would overflow in UTC, + # but not in the desired timezone (e.g. datetime.max at Chicago + # timezone) we can still save the day by shifting the value by the + # timezone offset and then replacing the timezone. 
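+            # (i.e. shift the microseconds by the session timezone offset so
+            # the epoch arithmetic stays within the datetime range, then
+            # attach the connection timezone to the shifted value.)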
+ if self._timezone: + utcoff = self._timezone.utcoffset( + datetime.min if micros < 0 else datetime.max + ) + if utcoff: + usoff = 1_000_000 * int(utcoff.total_seconds()) + try: + ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff) + except OverflowError: + pass # will raise downstream + else: + return ts.replace(tzinfo=self._timezone) + + if micros <= 0: + raise DataError("timestamp too small (before year 1)") from None + else: + raise DataError("timestamp too large (after year 10K)") from None + + +class IntervalLoader(Loader): + + _re_interval = re.compile( + rb""" + (?: ([-+]?\d+) \s+ years? \s* )? # Years + (?: ([-+]?\d+) \s+ mons? \s* )? # Months + (?: ([-+]?\d+) \s+ days? \s* )? # Days + (?: ([-+])? (\d+) : (\d+) : (\d+ (?:\.\d+)?) # Time + )? + """, + re.VERBOSE, + ) + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if self.connection: + ints = self.connection.pgconn.parameter_status(b"IntervalStyle") + if ints != b"postgres": + setattr(self, "load", self._load_notimpl) + + def load(self, data: Buffer) -> timedelta: + m = self._re_interval.match(data) + if not m: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}") + + ye, mo, da, sgn, ho, mi, se = m.groups() + days = 0 + seconds = 0.0 + + if ye: + days += 365 * int(ye) + if mo: + days += 30 * int(mo) + if da: + days += int(da) + + if ho: + seconds = 3600 * int(ho) + 60 * int(mi) + float(se) + if sgn == b"-": + seconds = -seconds + + try: + return timedelta(days=days, seconds=seconds) + except OverflowError as e: + s = bytes(data).decode("utf8", "replace") + raise DataError(f"can't parse interval {s!r}: {e}") from None + + def _load_notimpl(self, data: Buffer) -> timedelta: + s = bytes(data).decode("utf8", "replace") + ints = ( + self.connection + and self.connection.pgconn.parameter_status(b"IntervalStyle") + or b"unknown" + ).decode("utf8", "replace") + raise NotImplementedError( + f"can't parse interval with IntervalStyle {ints}: {s!r}" + ) + + +class IntervalBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> timedelta: + micros, days, months = _unpack_interval(data) + if months > 0: + years, months = divmod(months, 12) + days = days + 30 * months + 365 * years + elif months < 0: + years, months = divmod(-months, 12) + days = days - 30 * months - 365 * years + + try: + return timedelta(days=days, microseconds=micros) + except OverflowError as e: + raise DataError(f"can't parse interval: {e}") from None + + +def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes: + if conn: + ds = conn.pgconn.parameter_status(b"DateStyle") + if ds: + return ds + + return b"ISO, DMY" + + +def _get_timestamp_load_error( + conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None +) -> Exception: + s = bytes(data).decode("utf8", "replace") + + def is_overflow(s: str) -> bool: + if not s: + return False + + ds = _get_datestyle(conn) + if not ds.startswith(b"P"): # Postgres + return len(s.split()[0]) > 10 # date is first token + else: + return len(s.split()[-1]) > 4 # year is last token + + if s == "-infinity" or s.endswith("BC"): + return DataError("timestamp too small (before year 1): {s!r}") + elif s == "infinity" or is_overflow(s): + return DataError(f"timestamp too large (after year 10K): {s!r}") + else: + return DataError(f"can't parse timestamp {s!r}: {ex or '(unknown)'}") + + +_month_abbr = { + n: i + for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep 
Oct Nov Dec".split(), 1) +} + +# Pad to get microseconds from a fraction of seconds +_uspad = [0, 100_000, 10_000, 1_000, 100, 10, 1] + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("datetime.date", DateDumper) + adapters.register_dumper("datetime.date", DateBinaryDumper) + + # first register dumpers for 'timetz' oid, then the proper ones on time type. + adapters.register_dumper("datetime.time", TimeTzDumper) + adapters.register_dumper("datetime.time", TimeTzBinaryDumper) + adapters.register_dumper("datetime.time", TimeDumper) + adapters.register_dumper("datetime.time", TimeBinaryDumper) + + # first register dumpers for 'timestamp' oid, then the proper ones + # on the datetime type. + adapters.register_dumper("datetime.datetime", DatetimeNoTzDumper) + adapters.register_dumper("datetime.datetime", DatetimeNoTzBinaryDumper) + adapters.register_dumper("datetime.datetime", DatetimeDumper) + adapters.register_dumper("datetime.datetime", DatetimeBinaryDumper) + + adapters.register_dumper("datetime.timedelta", TimedeltaDumper) + adapters.register_dumper("datetime.timedelta", TimedeltaBinaryDumper) + + adapters.register_loader("date", DateLoader) + adapters.register_loader("date", DateBinaryLoader) + adapters.register_loader("time", TimeLoader) + adapters.register_loader("time", TimeBinaryLoader) + adapters.register_loader("timetz", TimetzLoader) + adapters.register_loader("timetz", TimetzBinaryLoader) + adapters.register_loader("timestamp", TimestampLoader) + adapters.register_loader("timestamp", TimestampBinaryLoader) + adapters.register_loader("timestamptz", TimestamptzLoader) + adapters.register_loader("timestamptz", TimestamptzBinaryLoader) + adapters.register_loader("interval", IntervalLoader) + adapters.register_loader("interval", IntervalBinaryLoader) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py new file mode 100644 index 0000000..d3c7387 --- /dev/null +++ b/psycopg/psycopg/types/enum.py @@ -0,0 +1,177 @@ +""" +Adapters for the enum type. +""" +from enum import Enum +from typing import Any, Dict, Generic, Optional, Mapping, Sequence +from typing import Tuple, Type, TypeVar, Union, cast +from typing_extensions import TypeAlias + +from .. import postgres +from .. 
import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from .._encodings import conn_encoding +from .._typeinfo import EnumInfo as EnumInfo # exported here + +E = TypeVar("E", bound=Enum) + +EnumDumpMap: TypeAlias = Dict[E, bytes] +EnumLoadMap: TypeAlias = Dict[bytes, E] +EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None] + + +class _BaseEnumLoader(Loader, Generic[E]): + """ + Loader for a specific Enum class + """ + + enum: Type[E] + _load_map: EnumLoadMap[E] + + def load(self, data: Buffer) -> E: + if not isinstance(data, bytes): + data = bytes(data) + + try: + return self._load_map[data] + except KeyError: + enc = conn_encoding(self.connection) + label = data.decode(enc, "replace") + raise e.DataError( + f"bad member for enum {self.enum.__qualname__}: {label!r}" + ) + + +class _BaseEnumDumper(Dumper, Generic[E]): + """ + Dumper for a specific Enum class + """ + + enum: Type[E] + _dump_map: EnumDumpMap[E] + + def dump(self, value: E) -> Buffer: + return self._dump_map[value] + + +class EnumDumper(Dumper): + """ + Dumper for a generic Enum class + """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._encoding = conn_encoding(self.connection) + + def dump(self, value: E) -> Buffer: + return value.name.encode(self._encoding) + + +class EnumBinaryDumper(EnumDumper): + format = Format.BINARY + + +def register_enum( + info: EnumInfo, + context: Optional[AdaptContext] = None, + enum: Optional[Type[E]] = None, + *, + mapping: EnumMapping[E] = None, +) -> None: + """Register the adapters to load and dump a enum type. + + :param info: The object with the information about the enum to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + :param enum: Python enum type matching to the PostgreSQL one. If `!None`, + a new enum will be generated and exposed as `EnumInfo.enum`. + :param mapping: Override the mapping between `!enum` members and `!info` + labels. + """ + + if not info: + raise TypeError("no info passed. 
Is the requested enum available?") + + if enum is None: + enum = cast(Type[E], Enum(info.name.title(), info.labels, module=__name__)) + + info.enum = enum + adapters = context.adapters if context else postgres.adapters + info.register(context) + + load_map = _make_load_map(info, enum, mapping, context) + attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map} + + name = f"{info.name.title()}Loader" + loader = type(name, (_BaseEnumLoader,), attribs) + adapters.register_loader(info.oid, loader) + + name = f"{info.name.title()}BinaryLoader" + loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY}) + adapters.register_loader(info.oid, loader) + + dump_map = _make_dump_map(info, enum, mapping, context) + attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map} + + name = f"{enum.__name__}Dumper" + dumper = type(name, (_BaseEnumDumper,), attribs) + adapters.register_dumper(info.enum, dumper) + + name = f"{enum.__name__}BinaryDumper" + dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY}) + adapters.register_dumper(info.enum, dumper) + + +def _make_load_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumLoadMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumLoadMap[E] = {} + for label in info.labels: + try: + member = enum[label] + except KeyError: + # tolerate a missing enum, assuming it won't be used. If it is we + # will get a DataError on fetch. + pass + else: + rv[label.encode(enc)] = member + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[label.encode(enc)] = member + + return rv + + +def _make_dump_map( + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumDumpMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumDumpMap[E] = {} + for member in enum: + rv[member] = member.name.encode(enc) + + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[member] = label.encode(enc) + + return rv + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, EnumBinaryDumper) + context.adapters.register_dumper(Enum, EnumDumper) diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py new file mode 100644 index 0000000..e1ab1d5 --- /dev/null +++ b/psycopg/psycopg/types/hstore.py @@ -0,0 +1,131 @@ +""" +Dict to hstore adaptation +""" + +# Copyright (C) 2021 The Psycopg Team + +import re +from typing import Dict, List, Optional +from typing_extensions import TypeAlias + +from .. import errors as e +from .. import postgres +from ..abc import Buffer, AdaptContext +from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader +from ..postgres import TEXT_OID +from .._typeinfo import TypeInfo + +_re_escape = re.compile(r'(["\\])') +_re_unescape = re.compile(r"\\(.)") + +_re_hstore = re.compile( + r""" + # hstore key: + # a string of normal or escaped chars + "((?: [^"\\] | \\. )*)" + \s*=>\s* # hstore value + (?: + NULL # the value can be null - not caught + # or a quoted string like the key + | "((?: [^"\\] | \\. )*)" + ) + (?:\s*,\s*|$) # pairs separated by comma or end of string. 
+""", + re.VERBOSE, +) + + +Hstore: TypeAlias = Dict[str, Optional[str]] + + +class BaseHstoreDumper(RecursiveDumper): + def dump(self, obj: Hstore) -> Buffer: + if not obj: + return b"" + + tokens: List[str] = [] + + def add_token(s: str) -> None: + tokens.append('"') + tokens.append(_re_escape.sub(r"\\\1", s)) + tokens.append('"') + + for k, v in obj.items(): + + if not isinstance(k, str): + raise e.DataError("hstore keys can only be strings") + add_token(k) + + tokens.append("=>") + + if v is None: + tokens.append("NULL") + elif not isinstance(v, str): + raise e.DataError("hstore keys can only be strings") + else: + add_token(v) + + tokens.append(",") + + del tokens[-1] + data = "".join(tokens) + dumper = self._tx.get_dumper(data, PyFormat.TEXT) + return dumper.dump(data) + + +class HstoreLoader(RecursiveLoader): + def load(self, data: Buffer) -> Hstore: + loader = self._tx.get_loader(TEXT_OID, self.format) + s: str = loader.load(data) + + rv: Hstore = {} + start = 0 + for m in _re_hstore.finditer(s): + if m is None or m.start() != start: + raise e.DataError(f"error parsing hstore pair at char {start}") + k = _re_unescape.sub(r"\1", m.group(1)) + v = m.group(2) + if v is not None: + v = _re_unescape.sub(r"\1", v) + + rv[k] = v + start = m.end() + + if start < len(s): + raise e.DataError(f"error parsing hstore: unparsed data after char {start}") + + return rv + + +def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump hstore. + + :param info: The object with the information about the hstore type. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'hstore' extension loaded?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # Generate and register a customized text dumper + class HstoreDumper(BaseHstoreDumper): + oid = info.oid + + adapters.register_dumper(dict, HstoreDumper) + + # register the text loader on the oid + adapters.register_loader(info.oid, HstoreLoader) diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py new file mode 100644 index 0000000..a80e0e4 --- /dev/null +++ b/psycopg/psycopg/types/json.py @@ -0,0 +1,232 @@ +""" +Adapers for JSON types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import json +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +from .. import abc +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap +from ..errors import DataError + +JsonDumpsFunction = Callable[[Any], str] +JsonLoadsFunction = Callable[[Union[str, bytes]], Any] + + +def set_json_dumps( + dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON serialisation function to store JSON objects in the database. + + :param dumps: The dump function to use. + :type dumps: `!Callable[[Any], str]` + :param context: Where to use the `!dumps` function. If not specified, use it + globally. 
+ :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default dumping JSON uses the builtin `json.dumps`. You can override + it to use a different JSON library or to use customised arguments. + + If the `Json` wrapper specified a `!dumps` function, use it in precedence + of the one set by this function. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonDumper._dumps = dumps + else: + adapters = context.adapters + + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + (Json, PyFormat.BINARY), + (Json, PyFormat.TEXT), + (Jsonb, PyFormat.BINARY), + (Jsonb, PyFormat.TEXT), + ] + dumper: Type[_JsonDumper] + for wrapper, format in grid: + base = _get_current_dumper(adapters, wrapper, format) + name = base.__name__ + if not base.__name__.startswith("Custom"): + name = f"Custom{name}" + dumper = type(name, (base,), {"_dumps": dumps}) + adapters.register_dumper(wrapper, dumper) + + +def set_json_loads( + loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None +) -> None: + """ + Set the JSON parsing function to fetch JSON objects from the database. + + :param loads: The load function to use. + :type loads: `!Callable[[bytes], Any]` + :param context: Where to use the `!loads` function. If not specified, use + it globally. + :type context: `~psycopg.Connection` or `~psycopg.Cursor` + + By default loading JSON uses the builtin `json.loads`. You can override + it to use a different JSON library or to use customised arguments. + """ + if context is None: + # If changing load function globally, just change the default on the + # global class + _JsonLoader._loads = loads + else: + # If the scope is smaller than global, create subclassess and register + # them in the appropriate scope. + grid = [ + ("json", JsonLoader), + ("json", JsonBinaryLoader), + ("jsonb", JsonbLoader), + ("jsonb", JsonbBinaryLoader), + ] + loader: Type[_JsonLoader] + for tname, base in grid: + loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads}) + context.adapters.register_loader(tname, loader) + + +class _JsonWrapper: + __slots__ = ("obj", "dumps") + + def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None): + self.obj = obj + self.dumps = dumps + + def __repr__(self) -> str: + sobj = repr(self.obj) + if len(sobj) > 40: + sobj = f"{sobj[:35]} ... ({len(sobj)} chars)" + return f"{self.__class__.__name__}({sobj})" + + +class Json(_JsonWrapper): + __slots__ = () + + +class Jsonb(_JsonWrapper): + __slots__ = () + + +class _JsonDumper(Dumper): + + # The globally used JSON dumps() function. It can be changed globally (by + # set_json_dumps) or by a subclass. 
+ _dumps: JsonDumpsFunction = json.dumps + + def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None): + super().__init__(cls, context) + self.dumps = self.__class__._dumps + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return dumps(obj.obj).encode() + + +class JsonDumper(_JsonDumper): + + oid = postgres.types["json"].oid + + +class JsonBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["json"].oid + + +class JsonbDumper(_JsonDumper): + + oid = postgres.types["jsonb"].oid + + +class JsonbBinaryDumper(_JsonDumper): + + format = Format.BINARY + oid = postgres.types["jsonb"].oid + + def dump(self, obj: _JsonWrapper) -> bytes: + dumps = obj.dumps or self.dumps + return b"\x01" + dumps(obj.obj).encode() + + +class _JsonLoader(Loader): + + # The globally used JSON loads() function. It can be changed globally (by + # set_json_loads) or by a subclass. + _loads: JsonLoadsFunction = json.loads + + def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None): + super().__init__(oid, context) + self.loads = self.__class__._loads + + def load(self, data: Buffer) -> Any: + # json.loads() cannot work on memoryview. + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +class JsonLoader(_JsonLoader): + pass + + +class JsonbLoader(_JsonLoader): + pass + + +class JsonBinaryLoader(_JsonLoader): + format = Format.BINARY + + +class JsonbBinaryLoader(_JsonLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Any: + if data and data[0] != 1: + raise DataError(f"unknown jsonb binary format: {data[0]}") + data = data[1:] + if not isinstance(data, bytes): + data = bytes(data) + return self.loads(data) + + +def _get_current_dumper( + adapters: AdaptersMap, cls: type, format: PyFormat +) -> Type[abc.Dumper]: + try: + return adapters.get_dumper(cls, format) + except e.ProgrammingError: + return _default_dumpers[cls, format] + + +_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = { + (Json, PyFormat.BINARY): JsonBinaryDumper, + (Json, PyFormat.TEXT): JsonDumper, + (Jsonb, PyFormat.BINARY): JsonbBinaryDumper, + (Jsonb, PyFormat.TEXT): JsonDumper, +} + + +def register_default_adapters(context: abc.AdaptContext) -> None: + adapters = context.adapters + + # Currently json binary format is nothing different than text, maybe with + # an extra memcopy we can avoid. + adapters.register_dumper(Json, JsonBinaryDumper) + adapters.register_dumper(Json, JsonDumper) + adapters.register_dumper(Jsonb, JsonbBinaryDumper) + adapters.register_dumper(Jsonb, JsonbDumper) + adapters.register_loader("json", JsonLoader) + adapters.register_loader("jsonb", JsonbLoader) + adapters.register_loader("json", JsonBinaryLoader) + adapters.register_loader("jsonb", JsonbBinaryLoader) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py new file mode 100644 index 0000000..3eaa7f1 --- /dev/null +++ b/psycopg/psycopg/types/multirange.py @@ -0,0 +1,514 @@ +""" +Support for multirange types adaptation. +""" + +# Copyright (C) 2021 The Psycopg Team + +from decimal import Decimal +from typing import Any, Generic, List, Iterable +from typing import MutableSequence, Optional, Type, Union, overload +from datetime import date, datetime + +from .. import errors as e +from ..
import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here + +from .range import Range, T, load_range_text, load_range_binary +from .range import dump_range_text, dump_range_binary, fail_dump + + +class Multirange(MutableSequence[Range[T]]): + """Python representation for a PostgreSQL multirange type. + + :param items: Sequence of ranges to initialise the object. + """ + + def __init__(self, items: Iterable[Range[T]] = ()): + self._ranges: List[Range[T]] = list(map(self._check_type, items)) + + def _check_type(self, item: Any) -> Range[Any]: + if not isinstance(item, Range): + raise TypeError( + f"Multirange is a sequence of Range, got {type(item).__name__}" + ) + return item + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._ranges!r})" + + def __str__(self) -> str: + return f"{{{', '.join(map(str, self._ranges))}}}" + + @overload + def __getitem__(self, index: int) -> Range[T]: + ... + + @overload + def __getitem__(self, index: slice) -> "Multirange[T]": + ... + + def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]": + if isinstance(index, int): + return self._ranges[index] + else: + return Multirange(self._ranges[index]) + + def __len__(self) -> int: + return len(self._ranges) + + @overload + def __setitem__(self, index: int, value: Range[T]) -> None: + ... + + @overload + def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: + ... + + def __setitem__( + self, + index: Union[int, slice], + value: Union[Range[T], Iterable[Range[T]]], + ) -> None: + if isinstance(index, int): + self._check_type(value) + self._ranges[index] = self._check_type(value) + elif not isinstance(value, Iterable): + raise TypeError("can only assign an iterable") + else: + value = map(self._check_type, value) + self._ranges[index] = value + + def __delitem__(self, index: Union[int, slice]) -> None: + del self._ranges[index] + + def insert(self, index: int, value: Range[T]) -> None: + self._ranges.insert(index, self._check_type(value)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return False + return self._ranges == other._ranges + + # Order is arbitrary but consistent + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges < other._ranges + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, Multirange): + return NotImplemented + return self._ranges > other._ranges + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + +# Subclasses to specify a specific subtype. 
Usually not needed + + +class Int4Multirange(Multirange[int]): + pass + + +class Int8Multirange(Multirange[int]): + pass + + +class NumericMultirange(Multirange[Decimal]): + pass + + +class DateMultirange(Multirange[date]): + pass + + +class TimestampMultirange(Multirange[datetime]): + pass + + +class TimestamptzMultirange(Multirange[datetime]): + pass + + +class BaseMultirangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Multirange: + return self + + item = self._get_item(obj) + if item is None: + return MultirangeDumper(self.cls) + + dumper: BaseMultirangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = MultirangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_multirange_oid(TEXT_OID) + else: + dumper.oid = self._get_multirange_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Multirange[Any]) -> Any: + """ + Return a member representative of the multirange + """ + for r in obj: + if r.lower is not None: + return r.lower + if r.upper is not None: + return r.upper + return None + + def _get_multirange_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class MultirangeDumper(BaseMultirangeDumper): + """ + Dumper for multirange types. + + The dumper can upgrade to one specific for a different range type. 
+ """ + + def dump(self, obj: Multirange[Any]) -> Buffer: + if not obj: + return b"{}" + + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [b"{"] + for r in obj: + out.append(dump_range_text(r, dump)) + out.append(b",") + out[-1] = b"}" + return b"".join(out) + + +class MultirangeBinaryDumper(BaseMultirangeDumper): + + format = Format.BINARY + + def dump(self, obj: Multirange[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + out: List[Buffer] = [pack_len(len(obj))] + for r in obj: + data = dump_range_binary(r, dump) + out.append(pack_len(len(data))) + out.append(data) + return b"".join(out) + + +class BaseMultirangeLoader(RecursiveLoader, Generic[T]): + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class MultirangeLoader(BaseMultirangeLoader[T]): + def load(self, data: Buffer) -> Multirange[T]: + if not data or data[0] != _START_INT: + raise e.DataError( + "malformed multirange starting with" + f" {bytes(data[:1]).decode('utf8', 'replace')}" + ) + + out = Multirange[T]() + if data == b"{}": + return out + + pos = 1 + data = data[pos:] + try: + while True: + r, pos = load_range_text(data, self._load) + out.append(r) + + sep = data[pos] # can raise IndexError + if sep == _SEP_INT: + data = data[pos + 1 :] + continue + elif sep == _END_INT: + if len(data) == pos + 1: + return out + else: + raise e.DataError( + "malformed multirange: data after closing brace" + ) + else: + raise e.DataError( + f"malformed multirange: found unexpected {chr(sep)}" + ) + + except IndexError: + raise e.DataError("malformed multirange: separator missing") + + return out + + +_SEP_INT = ord(",") +_START_INT = ord("{") +_END_INT = ord("}") + + +class MultirangeBinaryLoader(BaseMultirangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Multirange[T]: + nelems = unpack_len(data, 0)[0] + pos = 4 + out = Multirange[T]() + for i in range(nelems): + length = unpack_len(data, pos)[0] + pos += 4 + out.append(load_range_binary(data[pos : pos + length], self._load)) + pos += length + + if pos != len(data): + raise e.DataError("unexpected trailing data in multirange") + + return out + + +def register_multirange( + info: MultirangeInfo, context: Optional[AdaptContext] = None +) -> None: + """Register the adapters to load and dump a multirange type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. 
Is the requested multirange available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[MultirangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (MultirangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[MultirangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (MultirangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeDumper(MultirangeDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeDumper(MultirangeDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeDumper(MultirangeDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeDumper(MultirangeDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeDumper(MultirangeDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeDumper(MultirangeDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Binary dumpers for builtin multirange types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int4multirange"].oid + + +class Int8MultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["int8multirange"].oid + + +class NumericMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["nummultirange"].oid + + +class DateMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["datemultirange"].oid + + +class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tsmultirange"].oid + + +class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper): + oid = postgres.types["tstzmultirange"].oid + + +# Text loaders for builtin multirange types + + +class Int4MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeLoader(MultirangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeLoader(MultirangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeLoader(MultirangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZMultirangeLoader(MultirangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin multirange types + + +class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + 
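For illustration, a minimal usage sketch of register_multirange() defined above. It assumes PostgreSQL 14 or later and a custom range type created beforehand, e.g. CREATE TYPE strrange AS RANGE (subtype = text, multirange_type_name = strmultirange); the connection string and the type names are placeholder assumptions.

import psycopg
from psycopg.types.multirange import Multirange, MultirangeInfo, register_multirange
from psycopg.types.range import Range

with psycopg.connect("dbname=test") as conn:  # assumed DSN
    # Fetch the catalog information for the custom multirange and register it
    info = MultirangeInfo.fetch(conn, "strmultirange")
    register_multirange(info, conn)

    # Dump a Python Multirange and load it back as a Multirange of Range[str]
    mr = Multirange([Range("a", "g"), Range("j", None)])
    cur = conn.execute("SELECT %s::strmultirange", [mr])
    print(cur.fetchone()[0])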
+class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Multirange, MultirangeBinaryDumper) + adapters.register_dumper(Multirange, MultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeDumper) + adapters.register_dumper(DateMultirange, DateMultirangeDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper) + adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper) + adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper) + adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper) + adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper) + adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper) + adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper) + adapters.register_loader("int4multirange", Int4MultirangeLoader) + adapters.register_loader("int8multirange", Int8MultirangeLoader) + adapters.register_loader("nummultirange", NumericMultirangeLoader) + adapters.register_loader("datemultirange", DateMultirangeLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeLoader) + adapters.register_loader("int4multirange", Int4MultirangeBinaryLoader) + adapters.register_loader("int8multirange", Int8MultirangeBinaryLoader) + adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader) + adapters.register_loader("datemultirange", DateMultirangeBinaryLoader) + adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader) + adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader) diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py new file mode 100644 index 0000000..2f2c05b --- /dev/null +++ b/psycopg/psycopg/types/net.py @@ -0,0 +1,206 @@ +""" +Adapters for network types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, Type, Union, TYPE_CHECKING +from typing_extensions import TypeAlias + +from .. 
import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import ipaddress + +Address: TypeAlias = Union["ipaddress.IPv4Address", "ipaddress.IPv6Address"] +Interface: TypeAlias = Union["ipaddress.IPv4Interface", "ipaddress.IPv6Interface"] +Network: TypeAlias = Union["ipaddress.IPv4Network", "ipaddress.IPv6Network"] + +# These objects will be imported lazily +ip_address: Callable[[str], Address] = None # type: ignore[assignment] +ip_interface: Callable[[str], Interface] = None # type: ignore[assignment] +ip_network: Callable[[str], Network] = None # type: ignore[assignment] +IPv4Address: "Type[ipaddress.IPv4Address]" = None # type: ignore[assignment] +IPv6Address: "Type[ipaddress.IPv6Address]" = None # type: ignore[assignment] +IPv4Interface: "Type[ipaddress.IPv4Interface]" = None # type: ignore[assignment] +IPv6Interface: "Type[ipaddress.IPv6Interface]" = None # type: ignore[assignment] +IPv4Network: "Type[ipaddress.IPv4Network]" = None # type: ignore[assignment] +IPv6Network: "Type[ipaddress.IPv6Network]" = None # type: ignore[assignment] + +PGSQL_AF_INET = 2 +PGSQL_AF_INET6 = 3 +IPV4_PREFIXLEN = 32 +IPV6_PREFIXLEN = 128 + + +class _LazyIpaddress: + def _ensure_module(self) -> None: + global ip_address, ip_interface, ip_network + global IPv4Address, IPv6Address, IPv4Interface, IPv6Interface + global IPv4Network, IPv6Network + + if ip_address is None: + from ipaddress import ip_address, ip_interface, ip_network + from ipaddress import IPv4Address, IPv6Address + from ipaddress import IPv4Interface, IPv6Interface + from ipaddress import IPv4Network, IPv6Network + + +class InterfaceDumper(Dumper): + + oid = postgres.types["inet"].oid + + def dump(self, obj: Interface) -> bytes: + return str(obj).encode() + + +class NetworkDumper(Dumper): + + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + return str(obj).encode() + + +class _AIBinaryDumper(Dumper): + format = Format.BINARY + oid = postgres.types["inet"].oid + + +class AddressBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Address) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.max_prefixlen, 0, len(packed))) + return head + packed + + +class InterfaceBinaryDumper(_AIBinaryDumper): + def dump(self, obj: Interface) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.network.prefixlen, 0, len(packed))) + return head + packed + + +class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress): + """Either an address or an interface to inet + + Used when looking up by oid. 
+ """ + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._ensure_module() + + def dump(self, obj: Union[Address, Interface]) -> bytes: + packed = obj.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + if isinstance(obj, (IPv4Interface, IPv6Interface)): + prefixlen = obj.network.prefixlen + else: + prefixlen = obj.max_prefixlen + + head = bytes((family, prefixlen, 0, len(packed))) + return head + packed + + +class NetworkBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["cidr"].oid + + def dump(self, obj: Network) -> bytes: + packed = obj.network_address.packed + family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6 + head = bytes((family, obj.prefixlen, 1, len(packed))) + return head + packed + + +class _LazyIpaddressLoader(Loader, _LazyIpaddress): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._ensure_module() + + +class InetLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + if b"/" in data: + return ip_interface(data.decode()) + else: + return ip_address(data.decode()) + + +class InetBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Union[Address, Interface]: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + if prefix == IPV4_PREFIXLEN: + return IPv4Address(packed) + else: + return IPv4Interface((packed, prefix)) + else: + if prefix == IPV6_PREFIXLEN: + return IPv6Address(packed) + else: + return IPv6Interface((packed, prefix)) + + +class CidrLoader(_LazyIpaddressLoader): + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + return ip_network(data.decode()) + + +class CidrBinaryLoader(_LazyIpaddressLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Network: + if isinstance(data, memoryview): + data = bytes(data) + + prefix = data[1] + packed = data[4:] + if data[0] == PGSQL_AF_INET: + return IPv4Network((packed, prefix)) + else: + return IPv6Network((packed, prefix)) + + return ip_network(data.decode()) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("ipaddress.IPv4Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Address", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkDumper) + adapters.register_dumper("ipaddress.IPv4Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Address", AddressBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Interface", InterfaceBinaryDumper) + adapters.register_dumper("ipaddress.IPv4Network", NetworkBinaryDumper) + adapters.register_dumper("ipaddress.IPv6Network", NetworkBinaryDumper) + adapters.register_dumper(None, InetBinaryDumper) + adapters.register_loader("inet", InetLoader) + adapters.register_loader("inet", InetBinaryLoader) + adapters.register_loader("cidr", CidrLoader) + adapters.register_loader("cidr", CidrBinaryLoader) diff --git 
a/psycopg/psycopg/types/none.py b/psycopg/psycopg/types/none.py new file mode 100644 index 0000000..2ab857c --- /dev/null +++ b/psycopg/psycopg/types/none.py @@ -0,0 +1,25 @@ +""" +Adapters for None. +""" + +# Copyright (C) 2020 The Psycopg Team + +from ..abc import AdaptContext, NoneType +from ..adapt import Dumper + + +class NoneDumper(Dumper): + """ + Not a complete dumper as it doesn't implement dump(), but it implements + quote(), so it can be used in sql composition. + """ + + def dump(self, obj: None) -> bytes: + raise NotImplementedError("NULL is passed to Postgres in other ways") + + def quote(self, obj: None) -> bytes: + return b"NULL" + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(NoneType, NoneDumper) diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py new file mode 100644 index 0000000..1bd9329 --- /dev/null +++ b/psycopg/psycopg/types/numeric.py @@ -0,0 +1,515 @@ +""" +Adapers for numeric types. +""" + +# Copyright (C) 2020 The Psycopg Team + +import struct +from math import log +from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast +from decimal import Decimal, DefaultContext, Context + +from .. import postgres +from .. import errors as e +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader, PyFormat +from .._struct import pack_int2, pack_uint2, unpack_int2 +from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4 +from .._struct import pack_int8, unpack_int8 +from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8 + +# Exposed here +from .._wrappers import ( + Int2 as Int2, + Int4 as Int4, + Int8 as Int8, + IntNumeric as IntNumeric, + Oid as Oid, + Float4 as Float4, + Float8 as Float8, +) + + +class _IntDumper(Dumper): + def dump(self, obj: Any) -> Buffer: + t = type(obj) + if t is not int: + # Convert to int in order to dump IntEnum correctly + if issubclass(t, int): + obj = int(obj) + else: + raise e.DataError(f"integer expected, got {type(obj).__name__!r}") + + return str(obj).encode() + + def quote(self, obj: Any) -> Buffer: + value = self.dump(obj) + return value if obj >= 0 else b" " + value + + +class _SpecialValuesDumper(Dumper): + + _special: Dict[bytes, bytes] = {} + + def dump(self, obj: Any) -> bytes: + return str(obj).encode() + + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + + if value in self._special: + return self._special[value] + + return value if obj >= 0 else b" " + value + + +class FloatDumper(_SpecialValuesDumper): + + oid = postgres.types["float8"].oid + + _special = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", + } + + +class Float4Dumper(FloatDumper): + oid = postgres.types["float4"].oid + + +class FloatBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["float8"].oid + + def dump(self, obj: float) -> bytes: + return pack_float8(obj) + + +class Float4BinaryDumper(FloatBinaryDumper): + + oid = postgres.types["float4"].oid + + def dump(self, obj: float) -> bytes: + return pack_float4(obj) + + +class DecimalDumper(_SpecialValuesDumper): + + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> bytes: + if obj.is_nan(): + # cover NaN and sNaN + return b"NaN" + else: + return str(obj).encode() + + _special = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", + } + + +class Int2Dumper(_IntDumper): + oid = 
postgres.types["int2"].oid + + +class Int4Dumper(_IntDumper): + oid = postgres.types["int4"].oid + + +class Int8Dumper(_IntDumper): + oid = postgres.types["int8"].oid + + +class IntNumericDumper(_IntDumper): + oid = postgres.types["numeric"].oid + + +class OidDumper(_IntDumper): + oid = postgres.types["oid"].oid + + +class IntDumper(Dumper): + def dump(self, obj: Any) -> bytes: + raise TypeError( + f"{type(self).__name__} is a dispatcher to other dumpers:" + " dump() is not supposed to be called" + ) + + def get_key(self, obj: int, format: PyFormat) -> type: + return self.upgrade(obj, format).cls + + _int2_dumper = Int2Dumper(Int2) + _int4_dumper = Int4Dumper(Int4) + _int8_dumper = Int8Dumper(Int8) + _int_numeric_dumper = IntNumericDumper(IntNumeric) + + def upgrade(self, obj: int, format: PyFormat) -> Dumper: + if -(2**31) <= obj < 2**31: + if -(2**15) <= obj < 2**15: + return self._int2_dumper + else: + return self._int4_dumper + else: + if -(2**63) <= obj < 2**63: + return self._int8_dumper + else: + return self._int_numeric_dumper + + +class Int2BinaryDumper(Int2Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int2(obj) + + +class Int4BinaryDumper(Int4Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int4(obj) + + +class Int8BinaryDumper(Int8Dumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_int8(obj) + + +# Ratio between number of bits required to store a number and number of pg +# decimal digits required. +BIT_PER_PGDIGIT = log(2) / log(10_000) + + +class IntNumericBinaryDumper(IntNumericDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> Buffer: + return dump_int_to_numeric_binary(obj) + + +class OidBinaryDumper(OidDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + return pack_uint4(obj) + + +class IntBinaryDumper(IntDumper): + + format = Format.BINARY + + _int2_dumper = Int2BinaryDumper(Int2) + _int4_dumper = Int4BinaryDumper(Int4) + _int8_dumper = Int8BinaryDumper(Int8) + _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric) + + +class IntLoader(Loader): + def load(self, data: Buffer) -> int: + # it supports bytes directly + return int(data) + + +class Int2BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int2(data)[0] + + +class Int4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int4(data)[0] + + +class Int8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_int8(data)[0] + + +class OidBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> int: + return unpack_uint4(data)[0] + + +class FloatLoader(Loader): + def load(self, data: Buffer) -> float: + # it supports bytes directly + return float(data) + + +class Float4BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float4(data)[0] + + +class Float8BinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> float: + return unpack_float8(data)[0] + + +class NumericLoader(Loader): + def load(self, data: Buffer) -> Decimal: + if isinstance(data, memoryview): + data = bytes(data) + return Decimal(data.decode()) + + +DEC_DIGITS = 4 # decimal digits per Postgres "digit" +NUMERIC_POS = 0x0000 +NUMERIC_NEG = 0x4000 +NUMERIC_NAN = 0xC000 +NUMERIC_PINF = 0xD000 +NUMERIC_NINF = 0xF000 + +_decimal_special = { 
+ NUMERIC_NAN: Decimal("NaN"), + NUMERIC_PINF: Decimal("Infinity"), + NUMERIC_NINF: Decimal("-Infinity"), +} + + +class _ContextMap(DefaultDict[int, Context]): + """ + Cache for decimal contexts to use when the precision requires it. + + Note: if the default context is used (prec=28) you can get an invalid + operation or a rounding to 0: + + - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000') + - Decimal(1000).shift(25) = Decimal('0') + - Decimal(1000).shift(30) raises InvalidOperation + """ + + def __missing__(self, key: int) -> Context: + val = Context(prec=key) + self[key] = val + return val + + +_contexts = _ContextMap() +for i in range(DefaultContext.prec): + _contexts[i] = DefaultContext + +_unpack_numeric_head = cast( + Callable[[Buffer], Tuple[int, int, int, int]], + struct.Struct("!HhHH").unpack_from, +) +_pack_numeric_head = cast( + Callable[[int, int, int, int], bytes], + struct.Struct("!HhHH").pack, +) + + +class NumericBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Decimal: + ndigits, weight, sign, dscale = _unpack_numeric_head(data) + if sign == NUMERIC_POS or sign == NUMERIC_NEG: + val = 0 + for i in range(8, len(data), 2): + val = val * 10_000 + data[i] * 0x100 + data[i + 1] + + shift = dscale - (ndigits - weight - 1) * DEC_DIGITS + ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale] + return ( + Decimal(val if sign == NUMERIC_POS else -val) + .scaleb(-dscale, ctx) + .shift(shift, ctx) + ) + else: + try: + return _decimal_special[sign] + except KeyError: + raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None + + +NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0) +NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0) +NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0) + + +class DecimalBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Decimal) -> Buffer: + return dump_decimal_to_numeric_binary(obj) + + +class NumericDumper(DecimalDumper): + def dump(self, obj: Union[Decimal, int]) -> bytes: + if isinstance(obj, int): + return str(obj).encode() + else: + return super().dump(obj) + + +class NumericBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["numeric"].oid + + def dump(self, obj: Union[Decimal, int]) -> Buffer: + if isinstance(obj, int): + return dump_int_to_numeric_binary(obj) + else: + return dump_decimal_to_numeric_binary(obj) + + +def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]: + sign, digits, exp = obj.as_tuple() + if exp == "n" or exp == "N": # type: ignore[comparison-overlap] + return NUMERIC_NAN_BIN + elif exp == "F": # type: ignore[comparison-overlap] + return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN + + # Weights of py digits into a pg digit according to their positions. + # Starting with an index wi != 0 is equivalent to prepending 0's to + # the digits tuple, but without really changing it. + weights = (1000, 100, 10, 1) + wi = 0 + + ndigits = nzdigits = len(digits) + + # Find the last nonzero digit + while nzdigits > 0 and digits[nzdigits - 1] == 0: + nzdigits -= 1 + + if exp <= 0: + dscale = -exp + else: + dscale = 0 + # align the py digits to the pg digits if there's some py exponent + ndigits += exp % DEC_DIGITS + + if not nzdigits: + return _pack_numeric_head(0, 0, NUMERIC_POS, dscale) + + # Equivalent of 0-padding left to align the py digits to the pg digits + # but without changing the digits tuple. 
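+ # Worked example: Decimal("12.3") has digits=(1, 2, 3) and exp=-1, so dscale=1 + # and ndigits=3; mod=(3-1)%4=2 gives wi=2 (two virtual leading zeros), and the + # digits pack into the pg digits 0012 and 3000 with weight 0 and dscale 1, i.e. 12.3.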
+ mod = (ndigits - dscale) % DEC_DIGITS + if mod: + wi = DEC_DIGITS - mod + ndigits += wi + + tmp = nzdigits + wi + out = bytearray( + _pack_numeric_head( + tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1), # ndigits + (ndigits + exp) // DEC_DIGITS - 1, # weight + NUMERIC_NEG if sign else NUMERIC_POS, # sign + dscale, + ) + ) + + pgdigit = 0 + for i in range(nzdigits): + pgdigit += weights[wi] * digits[i] + wi += 1 + if wi >= DEC_DIGITS: + out += pack_uint2(pgdigit) + pgdigit = wi = 0 + + if pgdigit: + out += pack_uint2(pgdigit) + + return out + + +def dump_int_to_numeric_binary(obj: int) -> bytearray: + ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1 + out = bytearray(b"\x00\x00" * (ndigits + 4)) + if obj < 0: + sign = NUMERIC_NEG + obj = -obj + else: + sign = NUMERIC_POS + + out[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0) + i = 8 + (ndigits - 1) * 2 + while obj: + rem = obj % 10_000 + obj //= 10_000 + out[i : i + 2] = pack_uint2(rem) + i -= 2 + + return out + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(int, IntDumper) + adapters.register_dumper(int, IntBinaryDumper) + adapters.register_dumper(float, FloatDumper) + adapters.register_dumper(float, FloatBinaryDumper) + adapters.register_dumper(Int2, Int2Dumper) + adapters.register_dumper(Int4, Int4Dumper) + adapters.register_dumper(Int8, Int8Dumper) + adapters.register_dumper(IntNumeric, IntNumericDumper) + adapters.register_dumper(Oid, OidDumper) + + # The binary dumper is currently some 30% slower, so default to text + # (see tests/scripts/testdec.py for a rough benchmark) + # Also, must be after IntNumericDumper + adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper) + adapters.register_dumper("decimal.Decimal", DecimalDumper) + + # Used only by oid, can take both int and Decimal as input + adapters.register_dumper(None, NumericBinaryDumper) + adapters.register_dumper(None, NumericDumper) + + adapters.register_dumper(Float4, Float4Dumper) + adapters.register_dumper(Float8, FloatDumper) + adapters.register_dumper(Int2, Int2BinaryDumper) + adapters.register_dumper(Int4, Int4BinaryDumper) + adapters.register_dumper(Int8, Int8BinaryDumper) + adapters.register_dumper(Oid, OidBinaryDumper) + adapters.register_dumper(Float4, Float4BinaryDumper) + adapters.register_dumper(Float8, FloatBinaryDumper) + adapters.register_loader("int2", IntLoader) + adapters.register_loader("int4", IntLoader) + adapters.register_loader("int8", IntLoader) + adapters.register_loader("oid", IntLoader) + adapters.register_loader("int2", Int2BinaryLoader) + adapters.register_loader("int4", Int4BinaryLoader) + adapters.register_loader("int8", Int8BinaryLoader) + adapters.register_loader("oid", OidBinaryLoader) + adapters.register_loader("float4", FloatLoader) + adapters.register_loader("float8", FloatLoader) + adapters.register_loader("float4", Float4BinaryLoader) + adapters.register_loader("float8", Float8BinaryLoader) + adapters.register_loader("numeric", NumericLoader) + adapters.register_loader("numeric", NumericBinaryLoader) diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py new file mode 100644 index 0000000..c418480 --- /dev/null +++ b/psycopg/psycopg/types/range.py @@ -0,0 +1,700 @@ +""" +Support for range types adaptation. 
+""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple +from typing import cast +from decimal import Decimal +from datetime import date, datetime + +from .. import errors as e +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext, Buffer, Dumper, DumperKey +from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat +from .._struct import pack_len, unpack_len +from ..postgres import INVALID_OID, TEXT_OID +from .._typeinfo import RangeInfo as RangeInfo # exported here + +RANGE_EMPTY = 0x01 # range is empty +RANGE_LB_INC = 0x02 # lower bound is inclusive +RANGE_UB_INC = 0x04 # upper bound is inclusive +RANGE_LB_INF = 0x08 # lower bound is -infinity +RANGE_UB_INF = 0x10 # upper bound is +infinity + +_EMPTY_HEAD = bytes([RANGE_EMPTY]) + +T = TypeVar("T") + + +class Range(Generic[T]): + """Python representation for a PostgreSQL range type. + + :param lower: lower bound for the range. `!None` means unbound + :param upper: upper bound for the range. `!None` means unbound + :param bounds: one of the literal strings ``()``, ``[)``, ``(]``, ``[]``, + representing whether the lower or upper bounds are included + :param empty: if `!True`, the range is empty + + """ + + __slots__ = ("_lower", "_upper", "_bounds") + + def __init__( + self, + lower: Optional[T] = None, + upper: Optional[T] = None, + bounds: str = "[)", + empty: bool = False, + ): + if not empty: + if bounds not in ("[)", "(]", "()", "[]"): + raise ValueError("bound flags not valid: %r" % bounds) + + self._lower = lower + self._upper = upper + + # Make bounds consistent with infs + if lower is None and bounds[0] == "[": + bounds = "(" + bounds[1] + if upper is None and bounds[1] == "]": + bounds = bounds[0] + ")" + + self._bounds = bounds + else: + self._lower = self._upper = None + self._bounds = "" + + def __repr__(self) -> str: + if self._bounds: + args = f"{self._lower!r}, {self._upper!r}, {self._bounds!r}" + else: + args = "empty=True" + + return f"{self.__class__.__name__}({args})" + + def __str__(self) -> str: + if not self._bounds: + return "empty" + + items = [ + self._bounds[0], + str(self._lower), + ", ", + str(self._upper), + self._bounds[1], + ] + return "".join(items) + + @property + def lower(self) -> Optional[T]: + """The lower bound of the range. `!None` if empty or unbound.""" + return self._lower + + @property + def upper(self) -> Optional[T]: + """The upper bound of the range. 
`!None` if empty or unbound.""" + return self._upper + + @property + def bounds(self) -> str: + """The bounds string (two characters from '[', '(', ']', ')').""" + return self._bounds + + @property + def isempty(self) -> bool: + """`!True` if the range is empty.""" + return not self._bounds + + @property + def lower_inf(self) -> bool: + """`!True` if the range doesn't have a lower bound.""" + if not self._bounds: + return False + return self._lower is None + + @property + def upper_inf(self) -> bool: + """`!True` if the range doesn't have an upper bound.""" + if not self._bounds: + return False + return self._upper is None + + @property + def lower_inc(self) -> bool: + """`!True` if the lower bound is included in the range.""" + if not self._bounds or self._lower is None: + return False + return self._bounds[0] == "[" + + @property + def upper_inc(self) -> bool: + """`!True` if the upper bound is included in the range.""" + if not self._bounds or self._upper is None: + return False + return self._bounds[1] == "]" + + def __contains__(self, x: T) -> bool: + if not self._bounds: + return False + + if self._lower is not None: + if self._bounds[0] == "[": + # It doesn't seem that Python has an ABC for ordered types. + if x < self._lower: # type: ignore[operator] + return False + else: + if x <= self._lower: # type: ignore[operator] + return False + + if self._upper is not None: + if self._bounds[1] == "]": + if x > self._upper: # type: ignore[operator] + return False + else: + if x >= self._upper: # type: ignore[operator] + return False + + return True + + def __bool__(self) -> bool: + return bool(self._bounds) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Range): + return False + return ( + self._lower == other._lower + and self._upper == other._upper + and self._bounds == other._bounds + ) + + def __hash__(self) -> int: + return hash((self._lower, self._upper, self._bounds)) + + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Range): + return NotImplemented + for attr in ("_lower", "_upper", "_bounds"): + self_value = getattr(self, attr) + other_value = getattr(other, attr) + if self_value == other_value: + pass + elif self_value is None: + return True + elif other_value is None: + return False + else: + return cast(bool, self_value < other_value) + return False + + def __le__(self, other: Any) -> bool: + return self == other or self < other # type: ignore + + def __gt__(self, other: Any) -> bool: + if isinstance(other, Range): + return other < self + else: + return NotImplemented + + def __ge__(self, other: Any) -> bool: + return self == other or self > other # type: ignore + + def __getstate__(self) -> Dict[str, Any]: + return { + slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot) + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + for slot, value in state.items(): + setattr(self, slot, value) + + +# Subclasses to specify a specific subtype. Usually not needed: only needed +# in binary copy, where switching to text is not an option. 
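A short, self-contained sketch of the bound semantics described above (pure Python, no database involved; the values are arbitrary):

from psycopg.types.range import Range

r = Range(10, 20)             # default bounds "[)": lower included, upper excluded
assert 10 in r and 20 not in r
assert r.lower_inc and not r.upper_inc

u = Range(None, 20)           # a missing bound means "unbound"; "[" is normalised to "("
assert u.lower_inf and 5 in u

er = Range(empty=True)        # an empty range is falsy and contains nothing
assert not er and 10 not in er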
+ + +class Int4Range(Range[int]): + pass + + +class Int8Range(Range[int]): + pass + + +class NumericRange(Range[Decimal]): + pass + + +class DateRange(Range[date]): + pass + + +class TimestampRange(Range[datetime]): + pass + + +class TimestamptzRange(Range[datetime]): + pass + + +class BaseRangeDumper(RecursiveDumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + self._adapt_format = PyFormat.from_pq(self.format) + + def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey: + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self.cls + + item = self._get_item(obj) + if item is not None: + sd = self._tx.get_dumper(item, self._adapt_format) + return (self.cls, sd.get_key(item, format)) + else: + return (self.cls,) + + def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.cls is not Range: + return self + + item = self._get_item(obj) + if item is None: + return RangeDumper(self.cls) + + dumper: BaseRangeDumper + if type(item) is int: + # postgres won't cast int4range -> int8range so we must use + # text format and unknown oid here + sd = self._tx.get_dumper(item, PyFormat.TEXT) + dumper = RangeDumper(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = INVALID_OID + return dumper + + sd = self._tx.get_dumper(item, format) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + if sd.oid == INVALID_OID and isinstance(item, str): + # Work around the normal mapping where text is dumped as unknown + dumper.oid = self._get_range_oid(TEXT_OID) + else: + dumper.oid = self._get_range_oid(sd.oid) + + return dumper + + def _get_item(self, obj: Range[Any]) -> Any: + """ + Return a member representative of the range + """ + rv = obj.lower + return rv if rv is not None else obj.upper + + def _get_range_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + """ + info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid) + return info.oid if info else INVALID_OID + + +class RangeDumper(BaseRangeDumper): + """ + Dumper for range types. + + The dumper can upgrade to one specific for a different range type. 
+ """ + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_text(obj, dump) + + +def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if obj.isempty: + return b"empty" + + parts: List[Buffer] = [b"[" if obj.lower_inc else b"("] + + def dump_item(item: Any) -> Buffer: + ad = dump(item) + if not ad: + return b'""' + elif _re_needs_quotes.search(ad): + return b'"' + _re_esc.sub(rb"\1\1", ad) + b'"' + else: + return ad + + if obj.lower is not None: + parts.append(dump_item(obj.lower)) + + parts.append(b",") + + if obj.upper is not None: + parts.append(dump_item(obj.upper)) + + parts.append(b"]" if obj.upper_inc else b")") + + return b"".join(parts) + + +_re_needs_quotes = re.compile(rb'[",\\\s()\[\]]') +_re_esc = re.compile(rb"([\\\"])") + + +class RangeBinaryDumper(BaseRangeDumper): + + format = Format.BINARY + + def dump(self, obj: Range[Any]) -> Buffer: + item = self._get_item(obj) + if item is not None: + dump = self._tx.get_dumper(item, self._adapt_format).dump + else: + dump = fail_dump + + return dump_range_binary(obj, dump) + + +def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer: + if not obj: + return _EMPTY_HEAD + + out = bytearray([0]) # will replace the head later + + head = 0 + if obj.lower_inc: + head |= RANGE_LB_INC + if obj.upper_inc: + head |= RANGE_UB_INC + + if obj.lower is not None: + data = dump(obj.lower) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_LB_INF + + if obj.upper is not None: + data = dump(obj.upper) + out += pack_len(len(data)) + out += data + else: + head |= RANGE_UB_INF + + out[0] = head + return out + + +def fail_dump(obj: Any) -> Buffer: + raise e.InternalError("trying to dump a range element without information") + + +class BaseRangeLoader(RecursiveLoader, Generic[T]): + """Generic loader for a range. + + Subclasses must specify the oid of the subtype and the class to load. + """ + + subtype_oid: int + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load + + +class RangeLoader(BaseRangeLoader[T]): + def load(self, data: Buffer) -> Range[T]: + return load_range_text(data, self._load)[0] + + +def load_range_text( + data: Buffer, load: Callable[[Buffer], Any] +) -> Tuple[Range[Any], int]: + if data == b"empty": + return Range(empty=True), 5 + + m = _re_range.match(data) + if m is None: + raise e.DataError( + f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'" + ) + + lower = None + item = m.group(3) + if item is None: + item = m.group(2) + if item is not None: + lower = load(_re_undouble.sub(rb"\1", item)) + else: + lower = load(item) + + upper = None + item = m.group(5) + if item is None: + item = m.group(4) + if item is not None: + upper = load(_re_undouble.sub(rb"\1", item)) + else: + upper = load(item) + + bounds = (m.group(1) + m.group(6)).decode() + + return Range(lower, upper, bounds), m.end() + + +_re_range = re.compile( + rb""" + ( \(|\[ ) # lower bound flag + (?: # lower bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^",]+ ) # - or an unquoted string + )? # - or empty (not caught) + , + (?: # upper bound: + " ( (?: [^"] | "")* ) " # - a quoted string + | ( [^"\)\]]+ ) # - or an unquoted string + )? 
# - or empty (not caught) + ( \)|\] ) # upper bound flag + """, + re.VERBOSE, +) + +_re_undouble = re.compile(rb'(["\\])\1') + + +class RangeBinaryLoader(BaseRangeLoader[T]): + + format = Format.BINARY + + def load(self, data: Buffer) -> Range[T]: + return load_range_binary(data, self._load) + + +def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]: + head = data[0] + if head & RANGE_EMPTY: + return Range(empty=True) + + lb = "[" if head & RANGE_LB_INC else "(" + ub = "]" if head & RANGE_UB_INC else ")" + + pos = 1 # after the head + if head & RANGE_LB_INF: + min = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + min = load(data[pos : pos + length]) + pos += length + + if head & RANGE_UB_INF: + max = None + else: + length = unpack_len(data, pos)[0] + pos += 4 + max = load(data[pos : pos + length]) + pos += length + + return Range(min, max, lb + ub) + + +def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None: + """Register the adapters to load and dump a range type. + + :param info: The object with the information about the range to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a `Range` + with bounds parsed as the right subtype. + + .. note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the requested range available?") + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[RangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (RangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[RangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (RangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + +# Text dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeDumper(RangeDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeDumper(RangeDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeDumper(RangeDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeDumper(RangeDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeDumper(RangeDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeDumper(RangeDumper): + oid = postgres.types["tstzrange"].oid + + +# Binary dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. 
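As a usage illustration of register_range() above, a minimal sketch; the connection string and the type name are assumptions ("strrange" is taken to have been created with CREATE TYPE strrange AS RANGE (subtype = text)):

import psycopg
from psycopg.types.range import Range, RangeInfo, register_range

with psycopg.connect("dbname=test") as conn:  # assumed DSN
    info = RangeInfo.fetch(conn, "strrange")
    register_range(info, conn)

    # Round trip: the Range is dumped with the strrange oid and loaded back
    # with str bounds.
    cur = conn.execute("SELECT %s", [Range("a", "g")])
    print(cur.fetchone()[0])   # -> Range('a', 'g', '[)')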
+ + +class Int4RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int4range"].oid + + +class Int8RangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["int8range"].oid + + +class NumericRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["numrange"].oid + + +class DateRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["daterange"].oid + + +class TimestampRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tsrange"].oid + + +class TimestamptzRangeBinaryDumper(RangeBinaryDumper): + oid = postgres.types["tstzrange"].oid + + +# Text loaders for builtin range types + + +class Int4RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeLoader(RangeLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeLoader(RangeLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeLoader(RangeLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeLoader(RangeLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +# Binary loaders for builtin range types + + +class Int4RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int4"].oid + + +class Int8RangeBinaryLoader(RangeBinaryLoader[int]): + subtype_oid = postgres.types["int8"].oid + + +class NumericRangeBinaryLoader(RangeBinaryLoader[Decimal]): + subtype_oid = postgres.types["numeric"].oid + + +class DateRangeBinaryLoader(RangeBinaryLoader[date]): + subtype_oid = postgres.types["date"].oid + + +class TimestampRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamp"].oid + + +class TimestampTZRangeBinaryLoader(RangeBinaryLoader[datetime]): + subtype_oid = postgres.types["timestamptz"].oid + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper(Range, RangeBinaryDumper) + adapters.register_dumper(Range, RangeDumper) + adapters.register_dumper(Int4Range, Int4RangeDumper) + adapters.register_dumper(Int8Range, Int8RangeDumper) + adapters.register_dumper(NumericRange, NumericRangeDumper) + adapters.register_dumper(DateRange, DateRangeDumper) + adapters.register_dumper(TimestampRange, TimestampRangeDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeDumper) + adapters.register_dumper(Int4Range, Int4RangeBinaryDumper) + adapters.register_dumper(Int8Range, Int8RangeBinaryDumper) + adapters.register_dumper(NumericRange, NumericRangeBinaryDumper) + adapters.register_dumper(DateRange, DateRangeBinaryDumper) + adapters.register_dumper(TimestampRange, TimestampRangeBinaryDumper) + adapters.register_dumper(TimestamptzRange, TimestamptzRangeBinaryDumper) + adapters.register_loader("int4range", Int4RangeLoader) + adapters.register_loader("int8range", Int8RangeLoader) + adapters.register_loader("numrange", NumericRangeLoader) + adapters.register_loader("daterange", DateRangeLoader) + adapters.register_loader("tsrange", TimestampRangeLoader) + adapters.register_loader("tstzrange", TimestampTZRangeLoader) + adapters.register_loader("int4range", Int4RangeBinaryLoader) + adapters.register_loader("int8range", Int8RangeBinaryLoader) + adapters.register_loader("numrange", NumericRangeBinaryLoader) + adapters.register_loader("daterange", DateRangeBinaryLoader) + adapters.register_loader("tsrange", TimestampRangeBinaryLoader) + 
adapters.register_loader("tstzrange", TimestampTZRangeBinaryLoader) diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py new file mode 100644 index 0000000..e99f256 --- /dev/null +++ b/psycopg/psycopg/types/shapely.py @@ -0,0 +1,75 @@ +""" +Adapters for PostGIS geometries +""" + +from typing import Optional + +from .. import postgres +from ..abc import AdaptContext, Buffer +from ..adapt import Dumper, Loader +from ..pq import Format +from .._typeinfo import TypeInfo + + +try: + from shapely.wkb import loads, dumps + from shapely.geometry.base import BaseGeometry + +except ImportError: + raise ImportError( + "The module psycopg.types.shapely requires the package 'Shapely'" + " to be installed" + ) + + +class GeometryBinaryLoader(Loader): + format = Format.BINARY + + def load(self, data: Buffer) -> "BaseGeometry": + if not isinstance(data, bytes): + data = bytes(data) + return loads(data) + + +class GeometryLoader(Loader): + def load(self, data: Buffer) -> "BaseGeometry": + # it's a hex string in binary + if isinstance(data, memoryview): + data = bytes(data) + return loads(data.decode(), hex=True) + + +class BaseGeometryBinaryDumper(Dumper): + format = Format.BINARY + + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj) # type: ignore + + +class BaseGeometryDumper(Dumper): + def dump(self, obj: "BaseGeometry") -> bytes: + return dumps(obj, hex=True).encode() # type: ignore + + +def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None: + """Register Shapely dumper and loaders.""" + + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError("no info passed. Is the 'postgis' extension loaded?") + + info.register(context) + adapters = context.adapters if context else postgres.adapters + + class GeometryDumper(BaseGeometryDumper): + oid = info.oid + + class GeometryBinaryDumper(BaseGeometryBinaryDumper): + oid = info.oid + + adapters.register_loader(info.oid, GeometryBinaryLoader) + adapters.register_loader(info.oid, GeometryLoader) + # Default binary dump + adapters.register_dumper(BaseGeometry, GeometryDumper) + adapters.register_dumper(BaseGeometry, GeometryBinaryDumper) diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py new file mode 100644 index 0000000..cd5360d --- /dev/null +++ b/psycopg/psycopg/types/string.py @@ -0,0 +1,239 @@ +""" +Adapters for textual types. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Optional, Union, TYPE_CHECKING + +from .. import postgres +from ..pq import Format, Escaping +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader +from ..errors import DataError +from .._encodings import conn_encoding + +if TYPE_CHECKING: + from ..pq.abc import Escaping as EscapingProto + + +class _BaseStrDumper(Dumper): + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "utf-8" + + +class _StrBinaryDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in binary format. + + Subclasses shall specify the oids of real types (text, varchar, name...). 
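+
+    For instance (illustrative), the concrete binary dumper for ``varchar``
+    defined below is just::
+
+        class StrBinaryDumperVarchar(_StrBinaryDumper):
+            oid = postgres.types["varchar"].oid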
+ """ + + format = Format.BINARY + + def dump(self, obj: str) -> bytes: + # the server will raise DataError subclass if the string contains 0x00 + return obj.encode(self._encoding) + + +class _StrDumper(_BaseStrDumper): + """ + Base class to dump a Python strings to a Postgres text type, in text format. + + Subclasses shall specify the oids of real types (text, varchar, name...). + """ + + def dump(self, obj: str) -> bytes: + if "\x00" in obj: + raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes") + else: + return obj.encode(self._encoding) + + +# The next are concrete dumpers, each one specifying the oid they dump to. + + +class StrBinaryDumper(_StrBinaryDumper): + + oid = postgres.types["text"].oid + + +class StrBinaryDumperVarchar(_StrBinaryDumper): + + oid = postgres.types["varchar"].oid + + +class StrBinaryDumperName(_StrBinaryDumper): + + oid = postgres.types["name"].oid + + +class StrDumper(_StrDumper): + """ + Dumper for strings in text format to the text oid. + + Note that this dumper is not used by default because the type is too strict + and PostgreSQL would require an explicit casts to everything that is not a + text field. However it is useful where the unknown oid is ambiguous and the + text oid is required, for instance with variadic functions. + """ + + oid = postgres.types["text"].oid + + +class StrDumperVarchar(_StrDumper): + + oid = postgres.types["varchar"].oid + + +class StrDumperName(_StrDumper): + + oid = postgres.types["name"].oid + + +class StrDumperUnknown(_StrDumper): + """ + Dumper for strings in text format to the unknown oid. + + This dumper is the default dumper for strings and allows to use Python + strings to represent almost every data type. In a few places, however, the + unknown oid is not accepted (for instance in variadic functions such as + 'concat()'). In that case either a cast on the placeholder ('%s::text') or + the StrTextDumper should be used. + """ + + pass + + +class TextLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "" + + def load(self, data: Buffer) -> Union[bytes, str]: + if self._encoding: + if isinstance(data, memoryview): + data = bytes(data) + return data.decode(self._encoding) + else: + # return bytes for SQL_ASCII db + if not isinstance(data, bytes): + data = bytes(data) + return data + + +class TextBinaryLoader(TextLoader): + + format = Format.BINARY + + +class BytesDumper(Dumper): + + oid = postgres.types["bytea"].oid + _qprefix = b"" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self._esc = Escaping(self.connection.pgconn if self.connection else None) + + def dump(self, obj: Buffer) -> Buffer: + return self._esc.escape_bytea(obj) + + def quote(self, obj: Buffer) -> bytes: + escaped = self.dump(obj) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self.connection: + if not self._qprefix: + scs = self.connection.pgconn.parameter_status( + b"standard_conforming_strings" + ) + self._qprefix = b"'" if scs == b"on" else b" E'" + + return self._qprefix + escaped + b"'" + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. 
PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + rv: bytes = b" E'" + escaped + b"'" + if self._esc.escape_bytea(b"\x00") == b"\\000": + rv = rv.replace(b"\\", b"\\\\") + return rv + + +class BytesBinaryDumper(Dumper): + + format = Format.BINARY + oid = postgres.types["bytea"].oid + + def dump(self, obj: Buffer) -> Buffer: + return obj + + +class ByteaLoader(Loader): + + _escaping: "EscapingProto" + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + if not hasattr(self.__class__, "_escaping"): + self.__class__._escaping = Escaping() + + def load(self, data: Buffer) -> bytes: + return self._escaping.unescape_bytea(data) + + +class ByteaBinaryLoader(Loader): + + format = Format.BINARY + + def load(self, data: Buffer) -> Buffer: + return data + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + + # NOTE: the order the dumpers are registered is relevant. The last one + # registered becomes the default for each type. Usually, binary is the + # default dumper. For text we use the text dumper as default because it + # plays the role of unknown, and it can be cast automatically to other + # types. However, before that, we register dumper with 'text', 'varchar', + # 'name' oids, which will be used when a text dumper is looked up by oid. + adapters.register_dumper(str, StrBinaryDumperName) + adapters.register_dumper(str, StrBinaryDumperVarchar) + adapters.register_dumper(str, StrBinaryDumper) + adapters.register_dumper(str, StrDumperName) + adapters.register_dumper(str, StrDumperVarchar) + adapters.register_dumper(str, StrDumper) + adapters.register_dumper(str, StrDumperUnknown) + + adapters.register_loader(postgres.INVALID_OID, TextLoader) + adapters.register_loader("bpchar", TextLoader) + adapters.register_loader("name", TextLoader) + adapters.register_loader("text", TextLoader) + adapters.register_loader("varchar", TextLoader) + adapters.register_loader('"char"', TextLoader) + adapters.register_loader("bpchar", TextBinaryLoader) + adapters.register_loader("name", TextBinaryLoader) + adapters.register_loader("text", TextBinaryLoader) + adapters.register_loader("varchar", TextBinaryLoader) + adapters.register_loader('"char"', TextBinaryLoader) + + adapters.register_dumper(bytes, BytesDumper) + adapters.register_dumper(bytearray, BytesDumper) + adapters.register_dumper(memoryview, BytesDumper) + adapters.register_dumper(bytes, BytesBinaryDumper) + adapters.register_dumper(bytearray, BytesBinaryDumper) + adapters.register_dumper(memoryview, BytesBinaryDumper) + + adapters.register_loader("bytea", ByteaLoader) + adapters.register_loader(postgres.INVALID_OID, ByteaBinaryLoader) + adapters.register_loader("bytea", ByteaBinaryLoader) diff --git a/psycopg/psycopg/types/uuid.py b/psycopg/psycopg/types/uuid.py new file mode 100644 index 0000000..f92354c --- /dev/null +++ b/psycopg/psycopg/types/uuid.py @@ -0,0 +1,65 @@ +""" +Adapters for the UUID type. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Callable, Optional, TYPE_CHECKING + +from .. import postgres +from ..pq import Format +from ..abc import AdaptContext +from ..adapt import Buffer, Dumper, Loader + +if TYPE_CHECKING: + import uuid + +# Importing the uuid module is slow, so import it only on request. 
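+# The placeholder below is rebound to the real ``uuid.UUID`` class the first
+# time a loader is instantiated (see ``UUIDLoader.__init__``), so the import
+# cost is only paid when uuid data is actually used.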
+UUID: Callable[..., "uuid.UUID"] = None # type: ignore[assignment] + + +class UUIDDumper(Dumper): + + oid = postgres.types["uuid"].oid + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.hex.encode() + + +class UUIDBinaryDumper(UUIDDumper): + + format = Format.BINARY + + def dump(self, obj: "uuid.UUID") -> bytes: + return obj.bytes + + +class UUIDLoader(Loader): + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + global UUID + if UUID is None: + from uuid import UUID + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(data.decode()) + + +class UUIDBinaryLoader(UUIDLoader): + + format = Format.BINARY + + def load(self, data: Buffer) -> "uuid.UUID": + if isinstance(data, memoryview): + data = bytes(data) + return UUID(bytes=data) + + +def register_default_adapters(context: AdaptContext) -> None: + adapters = context.adapters + adapters.register_dumper("uuid.UUID", UUIDDumper) + adapters.register_dumper("uuid.UUID", UUIDBinaryDumper) + adapters.register_loader("uuid", UUIDLoader) + adapters.register_loader("uuid", UUIDBinaryLoader) diff --git a/psycopg/psycopg/version.py b/psycopg/psycopg/version.py new file mode 100644 index 0000000..a98bc35 --- /dev/null +++ b/psycopg/psycopg/version.py @@ -0,0 +1,14 @@ +""" +psycopg distribution version file. +""" + +# Copyright (C) 2020 The Psycopg Team + +# Use a versioning scheme as defined in +# https://www.python.org/dev/peps/pep-0440/ + +# STOP AND READ! if you change: +__version__ = "3.1.7" +# also change: +# - `docs/news.rst` to declare this as the current version or an unreleased one +# - `psycopg_c/psycopg_c/version.py` to the same version. diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py new file mode 100644 index 0000000..7abfc58 --- /dev/null +++ b/psycopg/psycopg/waiting.py @@ -0,0 +1,331 @@ +""" +Code concerned with waiting in different contexts (blocking, async, etc). + +These functions are designed to consume the generators returned by the +`generators` module function and to return their final value. + +""" + +# Copyright (C) 2020 The Psycopg Team + + +import os +import select +import selectors +from typing import Dict, Optional +from asyncio import get_event_loop, wait_for, Event, TimeoutError +from selectors import DefaultSelector + +from . import errors as e +from .abc import RV, PQGen, PQGenConn, WaitFunc +from ._enums import Wait as Wait, Ready as Ready # re-exported +from ._cmodule import _psycopg + +WAIT_R = Wait.R +WAIT_W = Wait.W +WAIT_RW = Wait.RW +READY_R = Ready.R +READY_W = Ready.W +READY_RW = Ready.RW + + +def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Consume `!gen`, scheduling `fileno` for completion when it is reported to + block. Once ready again send the ready state back to `!gen`. 
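+
+    Illustrative sketch (assuming the `generators` module of this package and
+    a libpq connection `pgconn`)::
+
+        from psycopg import generators
+        results = wait_selector(generators.execute(pgconn), pgconn.socket)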
+ """ + try: + s = next(gen) + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = None + while not rlist: + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + # note: this line should require a cast, but mypy doesn't complain + ready: Ready = rlist[0][1] + assert s & ready + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Wait for a connection generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :type timeout: float + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + try: + fileno, s = next(gen) + if not timeout: + timeout = None + with DefaultSelector() as sel: + while True: + sel.register(fileno, s) + rlist = sel.select(timeout=timeout) + sel.unregister(fileno) + if not rlist: + raise e.ConnectionTimeout("connection timeout expired") + ready: Ready = rlist[0][1] # type: ignore[assignment] + fileno, s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_async(gen: PQGen[RV], fileno: int) -> RV: + """ + Coroutine waiting for a generator to complete. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but exposing an `asyncio` interface. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready |= state # type: ignore[assignment] + ev.set() + + try: + s = next(gen) + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await ev.wait() + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Coroutine waiting for a connection generator to complete. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. If zero or None, wait indefinitely. + :return: whatever `!gen` returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. 
+ ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready = state + ev.set() + + try: + fileno, s = next(gen) + if not timeout: + timeout = None + while True: + reader = s & WAIT_R + writer = s & WAIT_W + if not reader and not writer: + raise e.InternalError(f"bad poll status: {s}") + ev.clear() + ready = 0 # type: ignore[assignment] + if reader: + loop.add_reader(fileno, wakeup, READY_R) + if writer: + loop.add_writer(fileno, wakeup, READY_W) + try: + await wait_for(ev.wait(), timeout) + finally: + if reader: + loop.remove_reader(fileno) + if writer: + loop.remove_writer(fileno) + fileno, s = gen.send(ready) + + except TimeoutError: + raise e.ConnectionTimeout("connection timeout expired") + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +# Specialised implementation of wait functions. + + +def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using select where supported. + """ + try: + s = next(gen) + + empty = () + fnlist = (fileno,) + while True: + rl, wl, xl = select.select( + fnlist if s & WAIT_R else empty, + fnlist if s & WAIT_W else empty, + fnlist, + timeout, + ) + ready = 0 + if rl: + ready = READY_R + if wl: + ready |= READY_W + if not ready: + continue + # assert s & ready + s = gen.send(ready) # type: ignore + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +poll_evmasks: Dict[Wait, int] + +if hasattr(selectors, "EpollSelector"): + poll_evmasks = { + WAIT_R: select.EPOLLONESHOT | select.EPOLLIN, + WAIT_W: select.EPOLLONESHOT | select.EPOLLOUT, + WAIT_RW: select.EPOLLONESHOT | select.EPOLLIN | select.EPOLLOUT, + } +else: + poll_evmasks = {} + + +def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: + """ + Wait for a generator using epoll where supported. + + Parameters are like for `wait()`. If it is detected that the best selector + strategy is `epoll` then this function will be used instead of `wait`. + + See also: https://linux.die.net/man/2/epoll_ctl + """ + try: + s = next(gen) + + if timeout is None or timeout < 0: + timeout = 0 + else: + timeout = int(timeout * 1000.0) + + with select.epoll() as epoll: + evmask = poll_evmasks[s] + epoll.register(fileno, evmask) + while True: + fileevs = None + while not fileevs: + fileevs = epoll.poll(timeout) + ev = fileevs[0][1] + ready = 0 + if ev & ~select.EPOLLOUT: + ready = READY_R + if ev & ~select.EPOLLIN: + ready |= READY_W + # assert s & ready + s = gen.send(ready) + evmask = poll_evmasks[s] + epoll.modify(fileno, evmask) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +if _psycopg: + wait_c = _psycopg.wait_c + + +# Choose the best wait strategy for the platform. +# +# the selectors objects have a generic interface but come with some overhead, +# so we also offer more finely tuned implementations. + +wait: WaitFunc + +# Allow the user to choose a specific function for testing +if "PSYCOPG_WAIT_FUNC" in os.environ: + fname = os.environ["PSYCOPG_WAIT_FUNC"] + if not fname.startswith("wait_") or fname not in globals(): + raise ImportError( + "PSYCOPG_WAIT_FUNC should be the name of an available wait function;" + f" got {fname!r}" + ) + wait = globals()[fname] + +elif _psycopg: + wait = wait_c + +elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None): + # On Windows, SelectSelector should be the default. 
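+    # (epoll and kqueue are not available there, so plain select() on the
+    # socket is the portable choice.)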
+ wait = wait_select + +elif selectors.DefaultSelector is getattr(selectors, "EpollSelector", None): + # NOTE: select seems more performing than epoll. It is admittedly unlikely + # that a platform has epoll but not select, so maybe we could kill + # wait_epoll altogether(). More testing to do. + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_epoll + +elif selectors.DefaultSelector is getattr(selectors, "KqueueSelector", None): + # wait_select is faster than wait_selector, probably because of less overhead + wait = wait_select if hasattr(selectors, "SelectSelector") else wait_selector + +else: + wait = wait_selector diff --git a/psycopg/pyproject.toml b/psycopg/pyproject.toml new file mode 100644 index 0000000..21e410c --- /dev/null +++ b/psycopg/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=49.2.0", "wheel>=0.37"] +build-backend = "setuptools.build_meta" diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg new file mode 100644 index 0000000..fdcb612 --- /dev/null +++ b/psycopg/setup.cfg @@ -0,0 +1,47 @@ +[metadata] +name = psycopg +description = PostgreSQL database adapter for Python +url = https://psycopg.org/psycopg3/ +author = Daniele Varrazzo +author_email = daniele.varrazzo@gmail.com +license = GNU Lesser General Public License v3 (LGPLv3) + +project_urls = + Homepage = https://psycopg.org/ + Code = https://github.com/psycopg/psycopg + Issue Tracker = https://github.com/psycopg/psycopg/issues + Download = https://pypi.org/project/psycopg/ + +classifiers = + Development Status :: 5 - Production/Stable + Intended Audience :: Developers + License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3) + Operating System :: MacOS :: MacOS X + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Database + Topic :: Database :: Front-Ends + Topic :: Software Development + Topic :: Software Development :: Libraries :: Python Modules + +long_description = file: README.rst +long_description_content_type = text/x-rst +license_files = LICENSE.txt + +[options] +python_requires = >= 3.7 +packages = find: +zip_safe = False +install_requires = + backports.zoneinfo >= 0.2.0; python_version < "3.9" + typing-extensions >= 4.1 + tzdata; sys_platform == "win32" + +[options.package_data] +psycopg = py.typed diff --git a/psycopg/setup.py b/psycopg/setup.py new file mode 100644 index 0000000..90d4380 --- /dev/null +++ b/psycopg/setup.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +PostgreSQL database adapter for Python - pure Python package +""" + +# Copyright (C) 2020 The Psycopg Team + +import os +from setuptools import setup + +# Move to the directory of setup.py: executing this file from another location +# (e.g. from the project root) will fail +here = os.path.abspath(os.path.dirname(__file__)) +if os.path.abspath(os.getcwd()) != here: + os.chdir(here) + +# Only for release 3.1.7. Not building binary packages because Scaleway +# has no runner available, but psycopg-binary 3.1.6 should work as well +# as the only change is in rows.py. 
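+# Hence `ext_versions` below accepts either 3.1.6 or 3.1.7 for the optional
+# psycopg-c / psycopg-binary packages.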
+version = "3.1.7" +ext_versions = ">= 3.1.6, <= 3.1.7" + +extras_require = { + # Install the C extension module (requires dev tools) + "c": [ + f"psycopg-c {ext_versions}", + ], + # Install the stand-alone C extension module + "binary": [ + f"psycopg-binary {ext_versions}", + ], + # Install the connection pool + "pool": [ + "psycopg-pool", + ], + # Requirements to run the test suite + "test": [ + "mypy >= 0.990", + "pproxy >= 2.7", + "pytest >= 6.2.5", + "pytest-asyncio >= 0.17", + "pytest-cov >= 3.0", + "pytest-randomly >= 3.10", + ], + # Requirements needed for development + "dev": [ + "black >= 22.3.0", + "dnspython >= 2.1", + "flake8 >= 4.0", + "mypy >= 0.990", + "types-setuptools >= 57.4", + "wheel >= 0.37", + ], + # Requirements needed to build the documentation + "docs": [ + "Sphinx >= 5.0", + "furo == 2022.6.21", + "sphinx-autobuild >= 2021.3.14", + "sphinx-autodoc-typehints >= 1.12", + ], +} + +setup( + version=version, + extras_require=extras_require, +) |
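+
+# Editor's note (illustrative, not part of the original file): the extras
+# above are selected at install time, e.g.:
+#
+#     pip install "psycopg[binary]"      # pure Python + pre-built C speedups
+#     pip install "psycopg[c,pool]"      # build the C extension, add the pool
+#     pip install "psycopg[test,dev]"    # contributor setup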